From ec811683302f601e476a8df760549cd67fe0c389 Mon Sep 17 00:00:00 2001 From: Xingyu Liu Date: Fri, 5 Dec 2025 01:46:27 -0800 Subject: [PATCH 01/23] introduce model arch config Signed-off-by: Xingyu Liu --- tests/config/model_arch_groundtruth.json | 308 +++++++++++++++ tests/config/test_model_arch_config.py | 87 ++++ tests/test_config.py | 1 + vllm/config/model.py | 311 ++------------- vllm/config/model_arch.py | 63 +++ .../model_arch_config_convertor.py | 374 ++++++++++++++++++ 6 files changed, 874 insertions(+), 270 deletions(-) create mode 100644 tests/config/model_arch_groundtruth.json create mode 100644 tests/config/test_model_arch_config.py create mode 100644 vllm/config/model_arch.py create mode 100644 vllm/transformers_utils/model_arch_config_convertor.py diff --git a/tests/config/model_arch_groundtruth.json b/tests/config/model_arch_groundtruth.json new file mode 100644 index 0000000000000..f8fabf4bd9ef1 --- /dev/null +++ b/tests/config/model_arch_groundtruth.json @@ -0,0 +1,308 @@ +{ + "Zyphra/Zamba2-7B-instruct": { + "architectures": [ + "Zamba2ForCausalLM" + ], + "model_type": "zamba2", + "text_model_type": "zamba2", + "hidden_size": 3584, + "total_num_hidden_layers": 81, + "total_num_attention_heads": 32, + "head_size": 224, + "vocab_size": 32000, + "total_num_kv_heads": 32, + "num_experts": 0, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "mosaicml/mpt-7b": { + "architectures": [ + "MPTForCausalLM" + ], + "model_type": "mpt", + "text_model_type": "mpt", + "hidden_size": 4096, + "total_num_hidden_layers": 32, + "total_num_attention_heads": 32, + "head_size": 128, + "vocab_size": 50432, + "total_num_kv_heads": 32, + "num_experts": 0, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "databricks/dbrx-instruct": { + "architectures": [ + "DbrxForCausalLM" + ], + "model_type": "dbrx", + 
"text_model_type": "dbrx", + "hidden_size": 6144, + "total_num_hidden_layers": 40, + "total_num_attention_heads": 48, + "head_size": 128, + "vocab_size": 100352, + "total_num_kv_heads": 8, + "num_experts": 0, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "tiiuae/falcon-7b": { + "architectures": [ + "FalconForCausalLM" + ], + "model_type": "falcon", + "text_model_type": "falcon", + "hidden_size": 4544, + "total_num_hidden_layers": 32, + "total_num_attention_heads": 71, + "head_size": 64, + "vocab_size": 65024, + "total_num_kv_heads": 1, + "num_experts": 0, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "tiiuae/falcon-40b": { + "architectures": [ + "FalconForCausalLM" + ], + "model_type": "falcon", + "text_model_type": "falcon", + "hidden_size": 8192, + "total_num_hidden_layers": 60, + "total_num_attention_heads": 128, + "head_size": 64, + "vocab_size": 65024, + "total_num_kv_heads": 8, + "num_experts": 0, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "luccafong/deepseek_mtp_main_random": { + "architectures": [ + "DeepseekV3ForCausalLM" + ], + "model_type": "deepseek_v3", + "text_model_type": "deepseek_v3", + "hidden_size": 2560, + "total_num_hidden_layers": 5, + "total_num_attention_heads": 32, + "head_size": 576, + "vocab_size": 129280, + "total_num_kv_heads": 32, + "num_experts": 72, + "is_deepseek_mla": true, + "is_multimodal_model": false, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "luccafong/deepseek_mtp_draft_random": { + "architectures": [ + "DeepseekV3ForCausalLM" + ], + "model_type": "deepseek_v3", + "text_model_type": "deepseek_v3", + "hidden_size": 2560, + "total_num_hidden_layers": 10, + "total_num_attention_heads": 32, + "head_size": 576, + "vocab_size": 
129280, + "total_num_kv_heads": 32, + "num_experts": 72, + "is_deepseek_mla": true, + "is_multimodal_model": false, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "Qwen/Qwen3-Next-80B-A3B-Instruct": { + "architectures": [ + "Qwen3NextForCausalLM" + ], + "model_type": "qwen3_next", + "text_model_type": "qwen3_next", + "hidden_size": 2048, + "total_num_hidden_layers": 48, + "total_num_attention_heads": 16, + "head_size": 256, + "vocab_size": 151936, + "total_num_kv_heads": 2, + "num_experts": 512, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "tiny-random/qwen3-next-moe": { + "architectures": [ + "Qwen3NextForCausalLM" + ], + "model_type": "qwen3_next", + "text_model_type": "qwen3_next", + "hidden_size": 8, + "total_num_hidden_layers": 4, + "total_num_attention_heads": 16, + "head_size": 32, + "vocab_size": 151936, + "total_num_kv_heads": 8, + "num_experts": 32, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "zai-org/GLM-4.5": { + "architectures": [ + "Glm4MoeForCausalLM" + ], + "model_type": "glm4_moe", + "text_model_type": "glm4_moe", + "hidden_size": 5120, + "total_num_hidden_layers": 92, + "total_num_attention_heads": 96, + "head_size": 128, + "vocab_size": 151552, + "total_num_kv_heads": 8, + "num_experts": 160, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "baidu/ERNIE-4.5-21B-A3B-PT": { + "architectures": [ + "Ernie4_5_MoeForCausalLM" + ], + "model_type": "ernie4_5_moe", + "text_model_type": "ernie4_5_moe", + "hidden_size": 2560, + "total_num_hidden_layers": 28, + "total_num_attention_heads": 20, + "head_size": 128, + "vocab_size": 103424, + "total_num_kv_heads": 4, + "num_experts": 64, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": 
"torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "lmsys/gpt-oss-20b-bf16": { + "architectures": [ + "GptOssForCausalLM" + ], + "model_type": "gpt_oss", + "text_model_type": "gpt_oss", + "hidden_size": 2880, + "total_num_hidden_layers": 24, + "total_num_attention_heads": 64, + "head_size": 64, + "vocab_size": 201088, + "total_num_kv_heads": 8, + "num_experts": 32, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "deepseek-ai/DeepSeek-V3.2-Exp": { + "architectures": [ + "DeepseekV32ForCausalLM" + ], + "model_type": "deepseek_v32", + "text_model_type": "deepseek_v32", + "hidden_size": 7168, + "total_num_hidden_layers": 61, + "total_num_attention_heads": 128, + "head_size": 576, + "vocab_size": 129280, + "total_num_kv_heads": 128, + "num_experts": 256, + "is_deepseek_mla": true, + "is_multimodal_model": false, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "meta-llama/Llama-4-Scout-17B-16E-Instruct": { + "architectures": [ + "Llama4ForConditionalGeneration" + ], + "model_type": "llama4", + "text_model_type": "llama4_text", + "hidden_size": 5120, + "total_num_hidden_layers": 48, + "total_num_attention_heads": 40, + "head_size": 128, + "vocab_size": 202048, + "total_num_kv_heads": 8, + "num_experts": 16, + "is_deepseek_mla": false, + "is_multimodal_model": true, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "nvidia/Llama-3_3-Nemotron-Super-49B-v1": { + "architectures": [ + "DeciLMForCausalLM" + ], + "model_type": "nemotron-nas", + "text_model_type": "nemotron-nas", + "hidden_size": 8192, + "total_num_hidden_layers": 80, + "total_num_attention_heads": 64, + "head_size": 128, + "vocab_size": 128256, + "total_num_kv_heads": 8, + "num_experts": 0, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "XiaomiMiMo/MiMo-7B-RL": { + 
"architectures": [ + "MiMoForCausalLM" + ], + "model_type": "mimo", + "text_model_type": "mimo", + "hidden_size": 4096, + "total_num_hidden_layers": 36, + "total_num_attention_heads": 32, + "head_size": 128, + "vocab_size": 151680, + "total_num_kv_heads": 8, + "num_experts": 0, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "meituan-longcat/LongCat-Flash-Chat": { + "architectures": [ + "LongcatFlashForCausalLM" + ], + "model_type": "longcat_flash", + "text_model_type": "longcat_flash", + "hidden_size": 6144, + "total_num_hidden_layers": 28, + "total_num_attention_heads": 64, + "head_size": 576, + "vocab_size": 131072, + "total_num_kv_heads": 64, + "num_experts": 512, + "is_deepseek_mla": true, + "is_multimodal_model": false, + "dtype": "torch.float32", + "dtype_original_type": "torch.dtype" + } +} diff --git a/tests/config/test_model_arch_config.py b/tests/config/test_model_arch_config.py new file mode 100644 index 0000000000000..2900cd977efab --- /dev/null +++ b/tests/config/test_model_arch_config.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json +from pathlib import Path + +import torch + +from vllm.config import ModelConfig + + +def test_model_arch_config(): + trust_remote_code_models = [ + "nvidia/Llama-3_3-Nemotron-Super-49B-v1", + "XiaomiMiMo/MiMo-7B-RL", + # Not available online right now + # "FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1", + "meituan-longcat/LongCat-Flash-Chat", + ] + models_to_test = [ + "Zyphra/Zamba2-7B-instruct", + "mosaicml/mpt-7b", + "databricks/dbrx-instruct", + "tiiuae/falcon-7b", + "tiiuae/falcon-40b", + "luccafong/deepseek_mtp_main_random", + "luccafong/deepseek_mtp_draft_random", + "Qwen/Qwen3-Next-80B-A3B-Instruct", + "tiny-random/qwen3-next-moe", + "zai-org/GLM-4.5", + "baidu/ERNIE-4.5-21B-A3B-PT", + # Select some models using base convertor for 
testing + "lmsys/gpt-oss-20b-bf16", + "deepseek-ai/DeepSeek-V3.2-Exp", + "meta-llama/Llama-4-Scout-17B-16E-Instruct", + ] + trust_remote_code_models + + groundtruth_path = Path(__file__).parent / "model_arch_groundtruth.json" + with open(groundtruth_path) as f: + model_arch_groundtruth = json.load(f) + + for model in models_to_test: + print(f"testing {model=}") + model_config = ModelConfig( + model, trust_remote_code=model in trust_remote_code_models + ) + + model_arch_config = model_config.model_arch_config + expected = model_arch_groundtruth[model] + assert model_arch_config.architectures == expected["architectures"] + assert model_arch_config.model_type == expected["model_type"] + assert model_arch_config.text_model_type == expected["text_model_type"] + assert model_arch_config.hidden_size == expected["hidden_size"] + assert ( + model_arch_config.total_num_hidden_layers + == expected["total_num_hidden_layers"] + ) + assert ( + model_arch_config.total_num_attention_heads + == expected["total_num_attention_heads"] + ) + assert model_arch_config.head_size == expected["head_size"] + assert model_arch_config.vocab_size == expected["vocab_size"] + assert model_arch_config.total_num_kv_heads == expected["total_num_kv_heads"] + assert model_arch_config.num_experts == expected["num_experts"] + assert model_arch_config.is_deepseek_mla == expected["is_deepseek_mla"] + assert model_arch_config.is_multimodal_model == expected["is_multimodal_model"] + + dtype = model_arch_config.torch_dtype + assert str(dtype) == expected["dtype"] + if expected["dtype_original_type"] == "str": + assert isinstance(dtype, str) + elif expected["dtype_original_type"] == "torch.dtype": + assert isinstance(dtype, torch.dtype) + else: + raise ValueError(f"Unknown dtype_original_type: {expected['dtype_original_type']}") + + # Test that model_config methods return expected values + assert model_config.architectures == expected["architectures"] + assert model_config.get_vocab_size() == expected["vocab_size"] +
assert model_config.get_hidden_size() == expected["hidden_size"] + assert model_config.get_head_size() == expected["head_size"] + assert model_config.get_total_num_kv_heads() == expected["total_num_kv_heads"] + assert model_config.get_num_experts() == expected["num_experts"] + assert ( + model_config.get_total_num_hidden_layers() + == expected["total_num_hidden_layers"] + ) diff --git a/tests/test_config.py b/tests/test_config.py index 77d3a7115978e..65807a17e9507 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + import logging import os from dataclasses import MISSING, Field, asdict, dataclass, field diff --git a/vllm/config/model.py b/vllm/config/model.py index 764bdf7000561..7f6313005737b 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -10,15 +10,17 @@ from typing import TYPE_CHECKING, Any, Literal, cast, get_args import torch from pydantic import ConfigDict, SkipValidation, field_validator, model_validator from pydantic.dataclasses import dataclass -from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from transformers.configuration_utils import ALLOWED_LAYER_TYPES import vllm.envs as envs from vllm.attention.backends.registry import AttentionBackendEnum +from vllm.config.model_arch import ( + ModelArchitectureConfig, +) from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig from vllm.config.pooler import PoolerConfig from vllm.config.scheduler import RunnerType -from vllm.config.utils import config, getattr_iter +from vllm.config.utils import config from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.transformers_utils.config import ( @@ -31,7 +33,6 @@ from vllm.transformers_utils.config import ( is_encoder_decoder, try_get_dense_modules, try_get_generation_config, - try_get_safetensors_metadata, try_get_tokenizer_config, 
uses_mrope, uses_xdrope_dim, @@ -42,10 +43,13 @@ from vllm.transformers_utils.gguf_utils import ( maybe_patch_hf_config_from_gguf, split_remote_gguf, ) +from vllm.transformers_utils.model_arch_config_convertor import ( + MODEL_ARCH_CONFIG_CONVERTORS, + ModelArchConfigConvertorBase, +) from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri from vllm.transformers_utils.utils import maybe_model_redirect from vllm.utils.import_utils import LazyLoader -from vllm.utils.torch_utils import common_broadcastable_dtype if TYPE_CHECKING: from transformers import PretrainedConfig @@ -504,6 +508,12 @@ class ModelConfig: self.hf_image_processor_config = get_hf_image_processor_config( self.model, hf_token=self.hf_token, revision=self.revision ) + self.model_arch_config = None + convertor_cls = MODEL_ARCH_CONFIG_CONVERTORS.get( + hf_config.model_type, ModelArchConfigConvertorBase + ) + convertor = convertor_cls(hf_config) + self.model_arch_config = convertor.convert(self.model, self.revision) architectures = self.architectures registry = self.registry @@ -765,7 +775,7 @@ class ModelConfig: @property def architectures(self) -> list[str]: - return getattr(self.hf_config, "architectures", []) + return self.model_arch_config.architectures @property def architecture(self) -> str: @@ -934,50 +944,16 @@ class ModelConfig: return "embed" - def _parse_quant_hf_config(self, hf_config: PretrainedConfig): - quant_cfg = getattr(hf_config, "quantization_config", None) - if quant_cfg is None: - # compressed-tensors uses a "compression_config" key - quant_cfg = getattr(hf_config, "compression_config", None) - - else: - # Set quant_method for ModelOpt models. 
- producer_name = quant_cfg.get("producer", {}).get("name") - if producer_name == "modelopt": - quant_algo = quant_cfg.get("quantization", {}).get("quant_algo") - if quant_algo == "FP8": - quant_cfg["quant_method"] = "modelopt" - elif quant_algo == "NVFP4": - quant_cfg["quant_method"] = "modelopt_fp4" - elif quant_algo is not None: - raise ValueError(f"Unknown ModelOpt quant algo: {quant_algo}") - - return quant_cfg - def _verify_quantization(self) -> None: supported_quantization = me_quant.QUANTIZATION_METHODS if self.quantization is not None: self.quantization = cast(me_quant.QuantizationMethods, self.quantization) # Parse quantization method from the HF model config, if available. - quant_cfg = self._parse_quant_hf_config(self.hf_config) - if quant_cfg is None and ( - text_config := getattr(self.hf_config, "text_config", None) - ): - # Check the text config as well for multi-modal models. - quant_cfg = self._parse_quant_hf_config(text_config) + quant_cfg = ModelArchConfigConvertorBase.get_quantization_config(self.hf_config) if quant_cfg is not None: - # Use the community standard 'quant_method' - quant_method = quant_cfg.get("quant_method", "").lower() - - # Normalize library names - quant_method = quant_method.replace( - "compressed_tensors", "compressed-tensors" - ) - - quant_cfg["quant_method"] = quant_method - + quant_method = quant_cfg["quant_method"] # Quantization methods which are overrides (i.e. they have a # `override_quantization_method` method) must be checked in order # of preference (this is particularly important for GPTQ). @@ -1059,7 +1035,7 @@ class ModelConfig: logger.warning( "CUDA graph is not supported for %s on ROCm yet, fallback " "to eager mode.", - self.hf_config.model_type, + self.model_arch_config.model_type, ) self.enforce_eager = True @@ -1070,11 +1046,9 @@ class ModelConfig: # TODO Remove this when bitsandbytes supports. 
""" is_bitsandbytes = self.quantization == "bitsandbytes" - has_quantization_config = ( - getattr(self.hf_config, "quantization_config", None) is not None - ) + has_quantization_config = self.model_arch_config.quantization_config is not None is_8bit = ( - self.hf_config.quantization_config.get("load_in_8bit", False) + self.model_arch_config.quantization_config.get("load_in_8bit", False) if has_quantization_config else False ) @@ -1128,9 +1102,7 @@ class ModelConfig: self, parallel_config: ParallelConfig, ) -> None: - total_num_attention_heads = getattr( - self.hf_text_config, "num_attention_heads", 0 - ) + total_num_attention_heads = self.model_arch_config.total_num_attention_heads tensor_parallel_size = parallel_config.tensor_parallel_size if total_num_attention_heads % tensor_parallel_size != 0: raise ValueError( @@ -1181,10 +1153,10 @@ class ModelConfig: return getattr(self.hf_text_config, "sliding_window", None) def get_vocab_size(self) -> int: - return getattr(self.hf_text_config, "vocab_size", 0) + return self.model_arch_config.vocab_size def get_hidden_size(self) -> int: - return getattr(self.hf_text_config, "hidden_size", 0) + return self.model_arch_config.hidden_size def get_inputs_embeds_size(self) -> int: # The size of inputs_embeds is usually identical to the size @@ -1198,29 +1170,7 @@ class ModelConfig: @property def is_deepseek_mla(self) -> bool: - if not hasattr(self.hf_text_config, "model_type"): - return False - elif self.hf_text_config.model_type in ( - "deepseek_v2", - "deepseek_v3", - "deepseek_v32", - "deepseek_mtp", - "kimi_k2", - "kimi_linear", - "longcat_flash", - "pangu_ultra_moe", - "pangu_ultra_moe_mtp", - ): - return self.hf_text_config.kv_lora_rank is not None - elif self.hf_text_config.model_type == "eagle": - # if the model is an EAGLE module, check for the - # underlying architecture - return ( - self.hf_text_config.model.model_type - in ("deepseek_v2", "deepseek_v3", "deepseek_v32") - and self.hf_text_config.kv_lora_rank is not 
None - ) - return False + return self.model_arch_config.is_deepseek_mla @cached_property def is_mm_prefix_lm(self) -> bool: @@ -1236,97 +1186,16 @@ class ModelConfig: return self.hf_config.model_type in MM_PREFIX_LM_MODELS def get_head_size(self) -> int: - # TODO remove hard code - if self.is_deepseek_mla: - qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim", 0) - if self.use_mla: - return self.hf_text_config.kv_lora_rank + qk_rope_head_dim - else: - qk_nope_head_dim = getattr(self.hf_text_config, "qk_nope_head_dim", 0) - if qk_rope_head_dim and qk_nope_head_dim: - return qk_rope_head_dim + qk_nope_head_dim - - if hasattr(self.hf_text_config, "model_type") and ( - self.hf_text_config.model_type == "zamba2" - ): - return self.hf_text_config.attention_head_dim - if self.is_attention_free: return 0 - - # NOTE: Some configs may set head_dim=None in the config - if getattr(self.hf_text_config, "head_dim", None) is not None: - return self.hf_text_config.head_dim - - # NOTE: Some models (such as PLaMo2.1) use `hidden_size_per_head` - if getattr(self.hf_text_config, "hidden_size_per_head", None) is not None: - return self.hf_text_config.hidden_size_per_head - - # FIXME(woosuk): This may not be true for all models. - return ( - self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads - ) + return self.model_arch_config.head_size def get_total_num_kv_heads(self) -> int: """Returns the total number of KV heads.""" - # For GPTBigCode & Falcon: - # NOTE: for falcon, when new_decoder_architecture is True, the - # multi_query flag is ignored and we use n_head_kv for the number of - # KV heads. - falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] - new_decoder_arch_falcon = ( - self.hf_config.model_type in falcon_model_types - and getattr(self.hf_config, "new_decoder_architecture", False) - ) - if not new_decoder_arch_falcon and getattr( - self.hf_text_config, "multi_query", False - ): - # Multi-query attention, only one KV head. 
- # Currently, tensor parallelism is not supported in this case. - return 1 - - # For DBRX and MPT - if self.hf_config.model_type == "mpt": - if "kv_n_heads" in self.hf_config.attn_config: - return self.hf_config.attn_config["kv_n_heads"] - return self.hf_config.num_attention_heads - if self.hf_config.model_type == "dbrx": - return getattr( - self.hf_config.attn_config, - "kv_n_heads", - self.hf_config.num_attention_heads, - ) - - if self.hf_config.model_type == "nemotron-nas": - for block in self.hf_config.block_configs: - if not block.attention.no_op: - return ( - self.hf_config.num_attention_heads - // block.attention.n_heads_in_group - ) - - raise RuntimeError("Couldn't determine number of kv heads") - if self.is_attention_free: return 0 - attributes = [ - # For Falcon: - "n_head_kv", - "num_kv_heads", - # For LLaMA-2: - "num_key_value_heads", - # For ChatGLM: - "multi_query_group_num", - ] - for attr in attributes: - num_kv_heads = getattr(self.hf_text_config, attr, None) - if num_kv_heads is not None: - return num_kv_heads - - # For non-grouped-query attention models, the number of KV heads is - # equal to the number of attention heads. 
- return self.hf_text_config.num_attention_heads + return self.model_arch_config.total_num_kv_heads def get_num_kv_heads(self, parallel_config: ParallelConfig) -> int: """Returns the number of KV heads per GPU.""" @@ -1342,46 +1211,14 @@ class ModelConfig: return max(1, total_num_kv_heads // parallel_config.tensor_parallel_size) def get_num_attention_heads(self, parallel_config: ParallelConfig) -> int: - num_heads = getattr(self.hf_text_config, "num_attention_heads", 0) + num_heads = self.model_arch_config.total_num_attention_heads return num_heads // parallel_config.tensor_parallel_size def get_num_experts(self) -> int: - """Returns the number of experts in the model.""" - num_expert_names = [ - "num_experts", # Jamba - "moe_num_experts", # Dbrx - "n_routed_experts", # DeepSeek - "num_local_experts", # Mixtral - ] - num_experts = getattr_iter(self.hf_text_config, num_expert_names, 0) - if isinstance(num_experts, list): - # Ernie VL's remote code uses list[int]... - # The values are always the same so we just take the first one. 
- return num_experts[0] - # Coerce to 0 if explicitly set to None - return num_experts or 0 + return self.model_arch_config.num_experts def get_total_num_hidden_layers(self) -> int: - if ( - self.hf_text_config.model_type == "deepseek_mtp" - or self.hf_config.model_type == "mimo_mtp" - or self.hf_config.model_type == "glm4_moe_mtp" - or self.hf_config.model_type == "ernie_mtp" - or self.hf_config.model_type == "qwen3_next_mtp" - or self.hf_config.model_type == "pangu_ultra_moe_mtp" - ): - total_num_hidden_layers = getattr( - self.hf_text_config, "num_nextn_predict_layers", 0 - ) - elif self.hf_config.model_type == "longcat_flash_mtp": - total_num_hidden_layers = getattr( - self.hf_text_config, "num_nextn_predict_layers", 1 - ) - else: - total_num_hidden_layers = getattr( - self.hf_text_config, "num_hidden_layers", 0 - ) - return total_num_hidden_layers + return self.model_arch_config.total_num_hidden_layers def get_layers_start_end_indices( self, parallel_config: ParallelConfig @@ -1432,9 +1269,7 @@ class ModelConfig: self.hf_text_config, "layers_block_type", None ) if layers_block_type_value is not None: - if hasattr(self.hf_text_config, "model_type") and ( - self.hf_text_config.model_type == "zamba2" - ): + if self.model_arch_config.text_model_type == "zamba2": if attn_block_type: return sum( t == "hybrid" for t in layers_block_type_value[start:end] @@ -1745,6 +1580,7 @@ class ModelConfig: ) max_model_len = _get_and_verify_max_len( hf_config=self.hf_text_config, + model_arch_config=self.model_arch_config, tokenizer_config=tokenizer_config, max_model_len=max_model_len, disable_sliding_window=self.disable_sliding_window, @@ -1969,46 +1805,6 @@ def _check_valid_dtype(model_type: str, dtype: torch.dtype): return True -def _find_dtype( - model_id: str, - config: PretrainedConfig, - *, - revision: str | None, -): - # NOTE: getattr(config, "dtype", torch.float32) is not correct - # because config.dtype can be None. 
- config_dtype = getattr(config, "dtype", None) - - # Fallbacks for multi-modal models if the root config - # does not define dtype - if config_dtype is None: - config_dtype = getattr(config.get_text_config(), "dtype", None) - if config_dtype is None and hasattr(config, "vision_config"): - config_dtype = getattr(config.vision_config, "dtype", None) - if config_dtype is None and hasattr(config, "encoder_config"): - config_dtype = getattr(config.encoder_config, "dtype", None) - - # Try to read the dtype of the weights if they are in safetensors format - if config_dtype is None: - repo_mt = try_get_safetensors_metadata(model_id, revision=revision) - - if repo_mt and (files_mt := repo_mt.files_metadata): - param_dtypes: set[torch.dtype] = { - _SAFETENSORS_TO_TORCH_DTYPE[dtype_str] - for file_mt in files_mt.values() - for dtype_str in file_mt.parameter_count - if dtype_str in _SAFETENSORS_TO_TORCH_DTYPE - } - - if param_dtypes: - return common_broadcastable_dtype(param_dtypes) - - if config_dtype is None: - config_dtype = torch.float32 - - return config_dtype - - def _resolve_auto_dtype( model_type: str, config_dtype: torch.dtype, @@ -2063,7 +1859,9 @@ def _get_and_verify_dtype( is_pooling_model: bool, revision: str | None = None, ) -> torch.dtype: - config_dtype = _find_dtype(model_id, config, revision=revision) + config_dtype = ModelArchConfigConvertorBase.get_torch_dtype( + config, model_id, revision=revision + ) model_type = config.model_type if isinstance(dtype, str): @@ -2126,6 +1924,7 @@ def _get_head_dtype( def _get_and_verify_max_len( hf_config: PretrainedConfig, + model_arch_config: ModelArchitectureConfig, tokenizer_config: dict | None, max_model_len: int | None, disable_sliding_window: bool, @@ -2134,36 +1933,9 @@ def _get_and_verify_max_len( encoder_config: Any | None = None, ) -> int: """Get and verify the model's maximum length.""" - derived_max_model_len = float("inf") - possible_keys = [ - # OPT - "max_position_embeddings", - # GPT-2 - "n_positions", - 
# MPT - "max_seq_len", - # ChatGLM2 - "seq_length", - # Command-R - "model_max_length", - # Whisper - "max_target_positions", - # Others - "max_sequence_length", - "max_seq_length", - "seq_len", - ] - # Choose the smallest "max_length" from the possible keys - max_len_key = None - for key in possible_keys: - max_len = getattr(hf_config, key, None) - if max_len is not None: - max_len_key = key if max_len < derived_max_model_len else max_len_key - derived_max_model_len = min(derived_max_model_len, max_len) - # For Command-R / Cohere, Cohere2 / Aya Vision models - if tmp_max_len := getattr(hf_config, "model_max_length", None): - max_len_key = "model_max_length" - derived_max_model_len = tmp_max_len + (derived_max_model_len, max_len_key) = ( + model_arch_config.derived_max_model_len_and_key + ) # If sliding window is manually disabled, max_length should be less # than the sliding window length in the model config. @@ -2196,10 +1968,9 @@ def _get_and_verify_max_len( default_max_len = 2048 logger.warning( - "The model's config.json does not contain any of the following " - "keys to determine the original maximum length of the model: " - "%s. Assuming the model's maximum length is %d.", - possible_keys, + "The model's config.json does not contain any of the keys " + "to determine the original maximum length of the model. 
" + "Assuming the model's maximum length is %d.", default_max_len, ) derived_max_model_len = default_max_len diff --git a/vllm/config/model_arch.py b/vllm/config/model_arch.py new file mode 100644 index 0000000000000..a35288c2330bf --- /dev/null +++ b/vllm/config/model_arch.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + +import torch +from pydantic import ConfigDict +from pydantic.dataclasses import dataclass + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) +class ModelArchitectureConfig: + """ + Configuration for model architecture that required by vLLM runtime + """ + + architectures: list[str] + """List of model architecture class names (e.g., ['LlamaForCausalLM']).""" + + model_type: str + """Model type identifier (e.g., 'llama', 'gpt_oss').""" + + text_model_type: str | None + """Text model type identifier (e.g., 'llama4_text').""" + + hidden_size: int + """Hidden size of the model.""" + + total_num_hidden_layers: int + """Number of hidden layers in the model.""" + + total_num_attention_heads: int + """Number of attention heads in the model.""" + + head_size: int + """Head dimension of the model.""" + + vocab_size: int + """Vocabulary size of the model.""" + + total_num_kv_heads: int + """Number of key value heads in the model.""" + + num_experts: int + """Number of experts in the model.""" + + quantization_config: dict[str, Any] | None + """Quantization configuration dictionary containing quantization parameters.""" + + torch_dtype: torch.dtype | str | None + """PyTorch data type for model weights (e.g., 'float16', 'bfloat16').""" + + is_multimodal_model: bool + """Whether the model is a multimodal model.""" + + is_deepseek_mla: bool + """Whether the model is a DeepSeek MLA model.""" + + derived_max_model_len_and_key: tuple[float, str | None] + """Derived maximum 
model length and key from the hf config.""" diff --git a/vllm/transformers_utils/model_arch_config_convertor.py b/vllm/transformers_utils/model_arch_config_convertor.py new file mode 100644 index 0000000000000..db5899eab5c12 --- /dev/null +++ b/vllm/transformers_utils/model_arch_config_convertor.py @@ -0,0 +1,374 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import TYPE_CHECKING + +import torch +from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE +from transformers import PretrainedConfig + +from vllm import envs +from vllm.config.model_arch import ( + ModelArchitectureConfig, +) +from vllm.config.utils import getattr_iter +from vllm.logger import init_logger +from vllm.transformers_utils.config import ( + get_hf_text_config, + try_get_safetensors_metadata, +) +from vllm.utils.import_utils import LazyLoader +from vllm.utils.torch_utils import common_broadcastable_dtype + +if TYPE_CHECKING: + import vllm.model_executor.models.registry as me_models_registry +else: + # Use lazy loading to avoid circular import + me_models_registry = LazyLoader( + "model_executor", globals(), "vllm.model_executor.models.registry" + ) + +logger = init_logger(__name__) + + +class ModelArchConfigConvertorBase: + def __init__(self, hf_config: PretrainedConfig): + self.hf_config = hf_config + self.hf_text_config = get_hf_text_config(hf_config) + + def get_num_hidden_layers(self) -> int: + return getattr(self.hf_text_config, "num_hidden_layers", 0) + + def get_total_num_attention_heads(self) -> int: + return getattr(self.hf_text_config, "num_attention_heads", 0) + + def get_vocab_size(self) -> int: + return getattr(self.hf_text_config, "vocab_size", 0) + + def get_hidden_size(self) -> int: + return getattr(self.hf_text_config, "hidden_size", 0) + + def get_head_size(self) -> int: + if self.is_deepseek_mla(): + qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim", 0) + if not 
envs.VLLM_MLA_DISABLE: + return self.hf_text_config.kv_lora_rank + qk_rope_head_dim + else: + qk_nope_head_dim = getattr(self.hf_text_config, "qk_nope_head_dim", 0) + if qk_rope_head_dim and qk_nope_head_dim: + return qk_rope_head_dim + qk_nope_head_dim + + # NOTE: Some configs may set head_dim=None in the config + if getattr(self.hf_text_config, "head_dim", None) is not None: + return self.hf_text_config.head_dim + + # NOTE: Some models (such as PLaMo2.1) use `hidden_size_per_head` + if getattr(self.hf_text_config, "hidden_size_per_head", None) is not None: + return self.hf_text_config.hidden_size_per_head + + # FIXME(woosuk): This may not be true for all models. + return ( + self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads + ) + + def get_total_num_kv_heads(self) -> int: + attributes = [ + # For Falcon: + "n_head_kv", + "num_kv_heads", + # For LLaMA-2: + "num_key_value_heads", + # For ChatGLM: + "multi_query_group_num", + ] + for attr in attributes: + num_kv_heads = getattr(self.hf_text_config, attr, None) + if num_kv_heads is not None: + return num_kv_heads + + return self.hf_text_config.num_attention_heads + + def get_num_experts(self) -> int: + """Returns the number of experts in the model.""" + num_expert_names = [ + "num_experts", # Jamba + "moe_num_experts", # Dbrx + "n_routed_experts", # DeepSeek + "num_local_experts", # Mixtral + ] + num_experts = getattr_iter(self.hf_text_config, num_expert_names, 0) + if isinstance(num_experts, list): + # Ernie VL's remote code uses list[int]... + # The values are always the same so we just take the first one. + return num_experts[0] + # Coerce to 0 if explicitly set to None + return num_experts or 0 + + @classmethod + def get_torch_dtype(cls, hf_config, model_id: str, revision: str | None): + # NOTE: getattr(config, "dtype", torch.float32) is not correct + # because config.dtype can be None. 
+ config_dtype = getattr(hf_config, "dtype", None) + + # Fallbacks for multi-modal models if the root config + # does not define dtype + if config_dtype is None: + config_dtype = getattr(hf_config.get_text_config(), "dtype", None) + if config_dtype is None and hasattr(hf_config, "vision_config"): + config_dtype = getattr(hf_config.vision_config, "dtype", None) + if config_dtype is None and hasattr(hf_config, "encoder_config"): + config_dtype = getattr(hf_config.encoder_config, "dtype", None) + + # Try to read the dtype of the weights if they are in safetensors format + if config_dtype is None: + repo_mt = try_get_safetensors_metadata(model_id, revision=revision) + + if repo_mt and (files_mt := repo_mt.files_metadata): + param_dtypes: set[torch.dtype] = { + _SAFETENSORS_TO_TORCH_DTYPE[dtype_str] + for file_mt in files_mt.values() + for dtype_str in file_mt.parameter_count + if dtype_str in _SAFETENSORS_TO_TORCH_DTYPE + } + + if param_dtypes: + return common_broadcastable_dtype(param_dtypes) + + if config_dtype is None: + config_dtype = torch.float32 + + return config_dtype + + @classmethod + def _normalize_quantization_config(cls, config: PretrainedConfig): + quant_cfg = getattr(config, "quantization_config", None) + if quant_cfg is None: + # compressed-tensors uses a "compression_config" key + quant_cfg = getattr(config, "compression_config", None) + + else: + # Set quant_method for ModelOpt models. 
+ producer_name = quant_cfg.get("producer", {}).get("name") + if producer_name == "modelopt": + quant_algo = quant_cfg.get("quantization", {}).get("quant_algo") + if quant_algo == "FP8": + quant_cfg["quant_method"] = "modelopt" + elif quant_algo == "NVFP4": + quant_cfg["quant_method"] = "modelopt_fp4" + elif quant_algo is not None: + raise ValueError(f"Unknown ModelOpt quant algo: {quant_algo}") + + if quant_cfg is not None: + # Use the community standard 'quant_method' + quant_method = quant_cfg.get("quant_method", "").lower() + + # Normalize library names + quant_method = quant_method.replace( + "compressed_tensors", "compressed-tensors" + ) + + quant_cfg["quant_method"] = quant_method + + return quant_cfg + + @classmethod + def get_quantization_config(cls, hf_config: PretrainedConfig): + quant_cfg = cls._normalize_quantization_config(hf_config) + if quant_cfg is None and ( + text_config := getattr(hf_config, "text_config", None) + ): + # Check the text config as well for multi-modal models. 
+ quant_cfg = cls._normalize_quantization_config(text_config) + return quant_cfg + + def is_deepseek_mla(self) -> bool: + if not hasattr(self.hf_text_config, "model_type"): + return False + elif self.hf_text_config.model_type in ( + "deepseek_v2", + "deepseek_v3", + "deepseek_v32", + "deepseek_mtp", + "kimi_k2", + "kimi_linear", + "longcat_flash", + "pangu_ultra_moe", + "pangu_ultra_moe_mtp", + ): + return self.hf_text_config.kv_lora_rank is not None + elif self.hf_text_config.model_type == "eagle": + # if the model is an EAGLE module, check for the + # underlying architecture + return ( + self.hf_text_config.model.model_type + in ("deepseek_v2", "deepseek_v3", "deepseek_v32") + and self.hf_text_config.kv_lora_rank is not None + ) + return False + + def derive_max_model_len_and_key(self) -> tuple[float, str | None]: + derived_max_model_len = float("inf") + possible_keys = [ + # OPT + "max_position_embeddings", + # GPT-2 + "n_positions", + # MPT + "max_seq_len", + # ChatGLM2 + "seq_length", + # Command-R + "model_max_length", + # Whisper + "max_target_positions", + # Others + "max_sequence_length", + "max_seq_length", + "seq_len", + ] + # Choose the smallest "max_length" from the possible keys + max_len_key = None + for key in possible_keys: + max_len = getattr(self.hf_text_config, key, None) + if max_len is not None: + if max_len < derived_max_model_len: + max_len_key = key + derived_max_model_len = min(derived_max_model_len, max_len) + + # For Command-R / Cohere, Cohere2 / Aya Vision models + if tmp_max_len := getattr(self.hf_text_config, "model_max_length", None): + max_len_key = "model_max_length" + derived_max_model_len = tmp_max_len + return derived_max_model_len, max_len_key + + def is_multimodal_model(self) -> bool: + return any( + multi_model_arch in self.hf_config.architectures + for multi_model_arch in me_models_registry._MULTIMODAL_MODELS + ) + + def convert(self, model_id: str, revision: str | None) -> ModelArchitectureConfig: + model_arch_config = 
ModelArchitectureConfig( + architectures=getattr(self.hf_config, "architectures", []), + model_type=self.hf_config.model_type, + text_model_type=getattr(self.hf_text_config, "model_type", None), + hidden_size=self.get_hidden_size(), + total_num_hidden_layers=self.get_num_hidden_layers(), + total_num_attention_heads=self.get_total_num_attention_heads(), + head_size=self.get_head_size(), + vocab_size=self.get_vocab_size(), + total_num_kv_heads=self.get_total_num_kv_heads(), + num_experts=self.get_num_experts(), + quantization_config=self.get_quantization_config(self.hf_config), + torch_dtype=self.get_torch_dtype(self.hf_config, model_id, revision), + is_multimodal_model=self.is_multimodal_model(), + is_deepseek_mla=self.is_deepseek_mla(), + derived_max_model_len_and_key=self.derive_max_model_len_and_key(), + ) + + return model_arch_config + + +class Zamba2ModelArchConfigConvertor(ModelArchConfigConvertorBase): + def get_head_size(self) -> int: + return getattr(self.hf_text_config, "attention_head_dim", 0) + + +class FalconModelArchConfigConvertor(ModelArchConfigConvertorBase): + def get_total_num_kv_heads(self) -> int: + # NOTE: for falcon, when new_decoder_architecture is True, the + # multi_query flag is ignored and we use n_head_kv for the number of + # KV heads. + new_decoder_arch_falcon = getattr( + self.hf_text_config, "new_decoder_architecture", False + ) + + if not new_decoder_arch_falcon and getattr( + self.hf_text_config, "multi_query", False + ): + # Multi-query attention, only one KV head. + return 1 + + # Use the base implementation which checks n_head_kv, num_kv_heads, etc. 
+ return super().get_total_num_kv_heads() + + +class MPTModelArchConfigConvertor(ModelArchConfigConvertorBase): + def get_total_num_kv_heads(self) -> int: + if "kv_n_heads" in self.hf_text_config.attn_config: + return self.hf_text_config.attn_config["kv_n_heads"] + return self.hf_text_config.num_attention_heads + + +class DbrxModelArchConfigConvertor(ModelArchConfigConvertorBase): + def get_total_num_kv_heads(self) -> int: + return getattr( + self.hf_text_config.attn_config, + "kv_n_heads", + self.hf_text_config.num_attention_heads, + ) + + +class NemotronNasModelArchConfigConvertor(ModelArchConfigConvertorBase): + def get_total_num_kv_heads(self) -> int: + for block in self.hf_text_config.block_configs: + if not block.attention.no_op: + return ( + self.hf_text_config.num_attention_heads + // block.attention.n_heads_in_group + ) + raise RuntimeError("Couldn't determine number of kv heads") + + +class DeepSeekMTPModelArchConfigConvertor(ModelArchConfigConvertorBase): + def get_num_hidden_layers(self) -> int: + return getattr(self.hf_text_config, "num_nextn_predict_layers", 0) + + +class MimoMTPModelArchConfigConvertor(ModelArchConfigConvertorBase): + def get_num_hidden_layers(self) -> int: + return getattr(self.hf_text_config, "num_nextn_predict_layers", 0) + + +class GLM4MoeMTPModelArchConfigConvertor(ModelArchConfigConvertorBase): + def get_num_hidden_layers(self) -> int: + return getattr(self.hf_text_config, "num_nextn_predict_layers", 0) + + +class ErnieMTPModelArchConfigConvertor(ModelArchConfigConvertorBase): + def get_num_hidden_layers(self) -> int: + return getattr(self.hf_text_config, "num_nextn_predict_layers", 0) + + +class Qwen3NextMTPModelArchConfigConvertor(ModelArchConfigConvertorBase): + def get_num_hidden_layers(self) -> int: + return getattr(self.hf_text_config, "num_nextn_predict_layers", 0) + + +class PanguUltraMoeMTPModelArchConfigConvertor(ModelArchConfigConvertorBase): + def get_num_hidden_layers(self) -> int: + return 
getattr(self.hf_text_config, "num_nextn_predict_layers", 0) + + +class LongCatFlashMTPModelArchConfigConvertor(ModelArchConfigConvertorBase): + def get_num_hidden_layers(self) -> int: + return getattr(self.hf_text_config, "num_nextn_predict_layers", 1) + + +# hf_config.model_type -> convertor class +MODEL_ARCH_CONFIG_CONVERTORS = { + "zamba2": Zamba2ModelArchConfigConvertor, + "mpt": MPTModelArchConfigConvertor, + "dbrx": DbrxModelArchConfigConvertor, + "falcon": FalconModelArchConfigConvertor, + "RefinedWeb": FalconModelArchConfigConvertor, + "RefinedWebModel": FalconModelArchConfigConvertor, + "nemotron-nas": NemotronNasModelArchConfigConvertor, + "deepseek_mtp": DeepSeekMTPModelArchConfigConvertor, + "qwen3_next_mtp": Qwen3NextMTPModelArchConfigConvertor, + "mimo_mtp": MimoMTPModelArchConfigConvertor, + "glm4_moe_mtp": GLM4MoeMTPModelArchConfigConvertor, + "ernie_mtp": ErnieMTPModelArchConfigConvertor, + "pangu_ultra_moe_mtp": PanguUltraMoeMTPModelArchConfigConvertor, + "longcat_flash_mtp": LongCatFlashMTPModelArchConfigConvertor, +} From 9b19e3b94fc4bc21ed0576105f5fba9955b2f092 Mon Sep 17 00:00:00 2001 From: Xingyu Liu Date: Fri, 5 Dec 2025 09:59:59 -0800 Subject: [PATCH 02/23] align with_hf_config Signed-off-by: Xingyu Liu --- vllm/config/model.py | 14 ++++++++------ vllm/config/vllm.py | 1 + 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/vllm/config/model.py b/vllm/config/model.py index 7f6313005737b..84e59dfc83d2e 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -508,12 +508,7 @@ class ModelConfig: self.hf_image_processor_config = get_hf_image_processor_config( self.model, hf_token=self.hf_token, revision=self.revision ) - self.model_arch_config = None - convertor_cls = MODEL_ARCH_CONFIG_CONVERTORS.get( - hf_config.model_type, ModelArchConfigConvertorBase - ) - convertor = convertor_cls(hf_config) - self.model_arch_config = convertor.convert(self.model, self.revision) + self.model_arch_config = self.get_model_arch_config() 
architectures = self.architectures registry = self.registry @@ -717,6 +712,13 @@ class ModelConfig: self._verify_cuda_graph() self._verify_bnb_config() + def get_model_arch_config(self) -> ModelArchitectureConfig: + convertor_cls = MODEL_ARCH_CONFIG_CONVERTORS.get( + self.hf_config.model_type, ModelArchConfigConvertorBase + ) + convertor = convertor_cls(self.hf_config) + return convertor.convert(self.model, self.revision) + @field_validator("tokenizer_mode", mode="after") def _lowercase_tokenizer_mode(cls, tokenizer_mode: str) -> str: return tokenizer_mode.lower() diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 614a3226cb711..f03d9d768c740 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -421,6 +421,7 @@ class VllmConfig: model_config = copy.deepcopy(self.model_config) model_config.hf_config = hf_config + model_config.model_arch_config = model_config.get_model_arch_config() return replace(self, model_config=model_config) From 0a4f4724eff09bd1b4aca498e335db5e775a3241 Mon Sep 17 00:00:00 2001 From: Xingyu Liu Date: Fri, 5 Dec 2025 10:17:49 -0800 Subject: [PATCH 03/23] remove dtype_original_type Signed-off-by: Xingyu Liu --- tests/config/model_arch_groundtruth.json | 51 ++++++++---------------- tests/config/test_model_arch_config.py | 8 ---- 2 files changed, 17 insertions(+), 42 deletions(-) diff --git a/tests/config/model_arch_groundtruth.json b/tests/config/model_arch_groundtruth.json index f8fabf4bd9ef1..6916ec50d9743 100644 --- a/tests/config/model_arch_groundtruth.json +++ b/tests/config/model_arch_groundtruth.json @@ -14,8 +14,7 @@ "num_experts": 0, "is_deepseek_mla": false, "is_multimodal_model": false, - "dtype": "torch.bfloat16", - "dtype_original_type": "torch.dtype" + "dtype": "torch.bfloat16" }, "mosaicml/mpt-7b": { "architectures": [ @@ -32,8 +31,7 @@ "num_experts": 0, "is_deepseek_mla": false, "is_multimodal_model": false, - "dtype": "torch.bfloat16", - "dtype_original_type": "torch.dtype" + "dtype": "torch.bfloat16" }, 
"databricks/dbrx-instruct": { "architectures": [ @@ -50,8 +48,7 @@ "num_experts": 0, "is_deepseek_mla": false, "is_multimodal_model": false, - "dtype": "torch.bfloat16", - "dtype_original_type": "torch.dtype" + "dtype": "torch.bfloat16" }, "tiiuae/falcon-7b": { "architectures": [ @@ -68,8 +65,7 @@ "num_experts": 0, "is_deepseek_mla": false, "is_multimodal_model": false, - "dtype": "torch.bfloat16", - "dtype_original_type": "torch.dtype" + "dtype": "torch.bfloat16" }, "tiiuae/falcon-40b": { "architectures": [ @@ -86,8 +82,7 @@ "num_experts": 0, "is_deepseek_mla": false, "is_multimodal_model": false, - "dtype": "torch.bfloat16", - "dtype_original_type": "torch.dtype" + "dtype": "torch.bfloat16" }, "luccafong/deepseek_mtp_main_random": { "architectures": [ @@ -104,8 +99,7 @@ "num_experts": 72, "is_deepseek_mla": true, "is_multimodal_model": false, - "dtype": "torch.bfloat16", - "dtype_original_type": "torch.dtype" + "dtype": "torch.bfloat16" }, "luccafong/deepseek_mtp_draft_random": { "architectures": [ @@ -122,8 +116,7 @@ "num_experts": 72, "is_deepseek_mla": true, "is_multimodal_model": false, - "dtype": "torch.bfloat16", - "dtype_original_type": "torch.dtype" + "dtype": "torch.bfloat16" }, "Qwen/Qwen3-Next-80B-A3B-Instruct": { "architectures": [ @@ -140,8 +133,7 @@ "num_experts": 512, "is_deepseek_mla": false, "is_multimodal_model": false, - "dtype": "torch.bfloat16", - "dtype_original_type": "torch.dtype" + "dtype": "torch.bfloat16" }, "tiny-random/qwen3-next-moe": { "architectures": [ @@ -158,8 +150,7 @@ "num_experts": 32, "is_deepseek_mla": false, "is_multimodal_model": false, - "dtype": "torch.bfloat16", - "dtype_original_type": "torch.dtype" + "dtype": "torch.bfloat16" }, "zai-org/GLM-4.5": { "architectures": [ @@ -176,8 +167,7 @@ "num_experts": 160, "is_deepseek_mla": false, "is_multimodal_model": false, - "dtype": "torch.bfloat16", - "dtype_original_type": "torch.dtype" + "dtype": "torch.bfloat16" }, "baidu/ERNIE-4.5-21B-A3B-PT": { "architectures": [ @@ 
-194,8 +184,7 @@ "num_experts": 64, "is_deepseek_mla": false, "is_multimodal_model": false, - "dtype": "torch.bfloat16", - "dtype_original_type": "torch.dtype" + "dtype": "torch.bfloat16" }, "lmsys/gpt-oss-20b-bf16": { "architectures": [ @@ -212,8 +201,7 @@ "num_experts": 32, "is_deepseek_mla": false, "is_multimodal_model": false, - "dtype": "torch.bfloat16", - "dtype_original_type": "torch.dtype" + "dtype": "torch.bfloat16" }, "deepseek-ai/DeepSeek-V3.2-Exp": { "architectures": [ @@ -230,8 +218,7 @@ "num_experts": 256, "is_deepseek_mla": true, "is_multimodal_model": false, - "dtype": "torch.bfloat16", - "dtype_original_type": "torch.dtype" + "dtype": "torch.bfloat16" }, "meta-llama/Llama-4-Scout-17B-16E-Instruct": { "architectures": [ @@ -248,8 +235,7 @@ "num_experts": 16, "is_deepseek_mla": false, "is_multimodal_model": true, - "dtype": "torch.bfloat16", - "dtype_original_type": "torch.dtype" + "dtype": "torch.bfloat16" }, "nvidia/Llama-3_3-Nemotron-Super-49B-v1": { "architectures": [ @@ -266,8 +252,7 @@ "num_experts": 0, "is_deepseek_mla": false, "is_multimodal_model": false, - "dtype": "torch.bfloat16", - "dtype_original_type": "torch.dtype" + "dtype": "torch.bfloat16" }, "XiaomiMiMo/MiMo-7B-RL": { "architectures": [ @@ -284,8 +269,7 @@ "num_experts": 0, "is_deepseek_mla": false, "is_multimodal_model": false, - "dtype": "torch.bfloat16", - "dtype_original_type": "torch.dtype" + "dtype": "torch.bfloat16" }, "meituan-longcat/LongCat-Flash-Chat": { "architectures": [ @@ -302,7 +286,6 @@ "num_experts": 512, "is_deepseek_mla": true, "is_multimodal_model": false, - "dtype": "torch.float32", - "dtype_original_type": "torch.dtype" + "dtype": "torch.float32" } } diff --git a/tests/config/test_model_arch_config.py b/tests/config/test_model_arch_config.py index 2900cd977efab..43750753ea514 100644 --- a/tests/config/test_model_arch_config.py +++ b/tests/config/test_model_arch_config.py @@ -3,8 +3,6 @@ import json from pathlib import Path -import torch - from vllm.config 
import ModelConfig @@ -67,12 +65,6 @@ def test_model_arch_config(): dtype = model_arch_config.torch_dtype assert str(dtype) == expected["dtype"] - if expected["dtype_original_type"] == "str": - assert isinstance(dtype, str) - elif expected["dtype_original_type"] == "torch.dtype": - assert isinstance(dtype, torch.dtype) - else: - raise ValueError(f"Unknown dtype_original_type: {expected['dtype']}") # Test that model_config methods return expected values assert model_config.architectures == expected["architectures"] From 1cf506d89ef9e4d1fa7c806aebaa2f9ed18e3e5a Mon Sep 17 00:00:00 2001 From: Xingyu Liu Date: Mon, 8 Dec 2025 16:13:31 -0800 Subject: [PATCH 04/23] fix attention free models Signed-off-by: Xingyu Liu --- .../model_arch_config_convertor.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/vllm/transformers_utils/model_arch_config_convertor.py b/vllm/transformers_utils/model_arch_config_convertor.py index db5899eab5c12..999f46083c6e0 100644 --- a/vllm/transformers_utils/model_arch_config_convertor.py +++ b/vllm/transformers_utils/model_arch_config_convertor.py @@ -269,6 +269,22 @@ class ModelArchConfigConvertorBase: return model_arch_config +class MambaModelArchConfigConvertor(ModelArchConfigConvertorBase): + def get_head_size(self) -> int: + return 0 + + def get_total_num_kv_heads(self) -> int: + return 0 + + +class TerratorchModelArchConfigConvertor(ModelArchConfigConvertorBase): + def get_head_size(self) -> int: + return 0 + + def get_total_num_kv_heads(self) -> int: + return 0 + + class Zamba2ModelArchConfigConvertor(ModelArchConfigConvertorBase): def get_head_size(self) -> int: return getattr(self.hf_text_config, "attention_head_dim", 0) @@ -357,6 +373,9 @@ class LongCatFlashMTPModelArchConfigConvertor(ModelArchConfigConvertorBase): # hf_config.model_type -> convertor class MODEL_ARCH_CONFIG_CONVERTORS = { + "mamba": MambaModelArchConfigConvertor, + "mamba2": MambaModelArchConfigConvertor, + "terratorch": 
TerratorchModelArchConfigConvertor, "zamba2": Zamba2ModelArchConfigConvertor, "mpt": MPTModelArchConfigConvertor, "dbrx": DbrxModelArchConfigConvertor, From c327dffce1d1b1096194a2babe811bb3c6749b50 Mon Sep 17 00:00:00 2001 From: Xingyu Liu Date: Mon, 8 Dec 2025 16:48:59 -0800 Subject: [PATCH 05/23] fix DummyConfig in tests Signed-off-by: Xingyu Liu --- tests/models/utils.py | 5 ++++- .../model_arch_config_convertor.py | 17 ++++++++++++----- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/tests/models/utils.py b/tests/models/utils.py index d84b4b820533e..479c056d543ec 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -14,6 +14,9 @@ from vllm.config.model import ModelConfig, ModelDType, RunnerOption from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs from vllm.multimodal.processing import InputProcessingContext from vllm.tokenizers import cached_tokenizer_from_config +from vllm.transformers_utils.model_arch_config_convertor import ( + ModelArchConfigConvertorBase, +) from .. import ci_envs from .registry import HF_EXAMPLE_MODELS @@ -488,7 +491,7 @@ def dummy_hf_overrides( # Only set MoE related config when the model has MoE layers. # Otherwise all models detected as MoE by _get_transformers_backend_cls. 
- if ModelConfig.get_num_experts(DummyConfig) > 0: + if ModelArchConfigConvertorBase.get_num_experts(text_config) > 0: update_dict.update( { "num_experts": num_experts, diff --git a/vllm/transformers_utils/model_arch_config_convertor.py b/vllm/transformers_utils/model_arch_config_convertor.py index 999f46083c6e0..16fd5b6b6dcbd 100644 --- a/vllm/transformers_utils/model_arch_config_convertor.py +++ b/vllm/transformers_utils/model_arch_config_convertor.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, final import torch from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE @@ -88,7 +88,9 @@ class ModelArchConfigConvertorBase: return self.hf_text_config.num_attention_heads - def get_num_experts(self) -> int: + @final + @classmethod + def get_num_experts(cls, hf_text_config: PretrainedConfig) -> int: """Returns the number of experts in the model.""" num_expert_names = [ "num_experts", # Jamba @@ -96,7 +98,7 @@ class ModelArchConfigConvertorBase: "n_routed_experts", # DeepSeek "num_local_experts", # Mixtral ] - num_experts = getattr_iter(self.hf_text_config, num_expert_names, 0) + num_experts = getattr_iter(hf_text_config, num_expert_names, 0) if isinstance(num_experts, list): # Ernie VL's remote code uses list[int]... # The values are always the same so we just take the first one. @@ -104,8 +106,11 @@ class ModelArchConfigConvertorBase: # Coerce to 0 if explicitly set to None return num_experts or 0 + @final @classmethod - def get_torch_dtype(cls, hf_config, model_id: str, revision: str | None): + def get_torch_dtype( + cls, hf_config: PretrainedConfig, model_id: str, revision: str | None + ): # NOTE: getattr(config, "dtype", torch.float32) is not correct # because config.dtype can be None. 
config_dtype = getattr(hf_config, "dtype", None) @@ -139,6 +144,7 @@ class ModelArchConfigConvertorBase: return config_dtype + @final @classmethod def _normalize_quantization_config(cls, config: PretrainedConfig): quant_cfg = getattr(config, "quantization_config", None) @@ -171,6 +177,7 @@ class ModelArchConfigConvertorBase: return quant_cfg + @final @classmethod def get_quantization_config(cls, hf_config: PretrainedConfig): quant_cfg = cls._normalize_quantization_config(hf_config) @@ -258,7 +265,7 @@ class ModelArchConfigConvertorBase: head_size=self.get_head_size(), vocab_size=self.get_vocab_size(), total_num_kv_heads=self.get_total_num_kv_heads(), - num_experts=self.get_num_experts(), + num_experts=self.get_num_experts(self.hf_text_config), quantization_config=self.get_quantization_config(self.hf_config), torch_dtype=self.get_torch_dtype(self.hf_config, model_id, revision), is_multimodal_model=self.is_multimodal_model(), From a8cc81a695297c51981bfb51157854ddbeb53ce7 Mon Sep 17 00:00:00 2001 From: Xingyu Liu Date: Tue, 9 Dec 2025 00:32:57 -0800 Subject: [PATCH 06/23] add tests for attention free models Signed-off-by: Xingyu Liu --- tests/config/model_arch_groundtruth.json | 51 +++++++++++++++++++ tests/config/test_model_arch_config.py | 3 ++ .../model_arch_config_convertor.py | 3 +- 3 files changed, 55 insertions(+), 2 deletions(-) diff --git a/tests/config/model_arch_groundtruth.json b/tests/config/model_arch_groundtruth.json index 6916ec50d9743..c6d321f6c3257 100644 --- a/tests/config/model_arch_groundtruth.json +++ b/tests/config/model_arch_groundtruth.json @@ -1,4 +1,55 @@ { + "state-spaces/mamba-130m-hf": { + "architectures": [ + "MambaForCausalLM" + ], + "model_type": "mamba", + "text_model_type": "mamba", + "hidden_size": 768, + "total_num_hidden_layers": 24, + "total_num_attention_heads": 0, + "head_size": 0, + "vocab_size": 50280, + "total_num_kv_heads": 0, + "num_experts": 0, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": 
"torch.float32" + }, + "mistralai/Mamba-Codestral-7B-v0.1": { + "architectures": [ + "Mamba2ForCausalLM" + ], + "model_type": "mamba", + "text_model_type": "mamba", + "hidden_size": 4096, + "total_num_hidden_layers": 64, + "total_num_attention_heads": 0, + "head_size": 0, + "vocab_size": 32768, + "total_num_kv_heads": 0, + "num_experts": 0, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "torch.bfloat16" + }, + "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11": { + "architectures": [ + "Terratorch" + ], + "model_type": "timm_wrapper", + "text_model_type": "timm_wrapper", + "hidden_size": 0, + "total_num_hidden_layers": 0, + "total_num_attention_heads": 0, + "head_size": 0, + "vocab_size": 0, + "total_num_kv_heads": 0, + "num_experts": 0, + "is_deepseek_mla": false, + "is_multimodal_model": true, + "dtype": "torch.float32" + }, "Zyphra/Zamba2-7B-instruct": { "architectures": [ "Zamba2ForCausalLM" diff --git a/tests/config/test_model_arch_config.py b/tests/config/test_model_arch_config.py index 43750753ea514..365cc1104ccaf 100644 --- a/tests/config/test_model_arch_config.py +++ b/tests/config/test_model_arch_config.py @@ -15,6 +15,9 @@ def test_model_arch_config(): "meituan-longcat/LongCat-Flash-Chat", ] models_to_test = [ + "state-spaces/mamba-130m-hf", + "mistralai/Mamba-Codestral-7B-v0.1", + "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", "Zyphra/Zamba2-7B-instruct", "mosaicml/mpt-7b", "databricks/dbrx-instruct", diff --git a/vllm/transformers_utils/model_arch_config_convertor.py b/vllm/transformers_utils/model_arch_config_convertor.py index 16fd5b6b6dcbd..d1e28cbe558bb 100644 --- a/vllm/transformers_utils/model_arch_config_convertor.py +++ b/vllm/transformers_utils/model_arch_config_convertor.py @@ -381,8 +381,7 @@ class LongCatFlashMTPModelArchConfigConvertor(ModelArchConfigConvertorBase): # hf_config.model_type -> convertor class MODEL_ARCH_CONFIG_CONVERTORS = { "mamba": MambaModelArchConfigConvertor, - "mamba2": 
MambaModelArchConfigConvertor, - "terratorch": TerratorchModelArchConfigConvertor, + "timm_wrapper": TerratorchModelArchConfigConvertor, "zamba2": Zamba2ModelArchConfigConvertor, "mpt": MPTModelArchConfigConvertor, "dbrx": DbrxModelArchConfigConvertor, From 48eeb1ffbad37e30d61c1e2909f141f98fa337a5 Mon Sep 17 00:00:00 2001 From: Xingyu Liu Date: Tue, 9 Dec 2025 11:58:08 -0800 Subject: [PATCH 07/23] remove multimodal in model_arch_config Signed-off-by: Xingyu Liu --- tests/config/model_arch_groundtruth.json | 40 +++++++++---------- tests/config/test_model_arch_config.py | 2 - vllm/config/model_arch.py | 3 -- .../model_arch_config_convertor.py | 18 +-------- 4 files changed, 21 insertions(+), 42 deletions(-) diff --git a/tests/config/model_arch_groundtruth.json b/tests/config/model_arch_groundtruth.json index c6d321f6c3257..c3540ab5bdf03 100644 --- a/tests/config/model_arch_groundtruth.json +++ b/tests/config/model_arch_groundtruth.json @@ -13,7 +13,7 @@ "total_num_kv_heads": 0, "num_experts": 0, "is_deepseek_mla": false, - "is_multimodal_model": false, + "supports_multimodal": false, "dtype": "torch.float32" }, "mistralai/Mamba-Codestral-7B-v0.1": { @@ -30,7 +30,7 @@ "total_num_kv_heads": 0, "num_experts": 0, "is_deepseek_mla": false, - "is_multimodal_model": false, + "supports_multimodal": false, "dtype": "torch.bfloat16" }, "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11": { @@ -47,7 +47,7 @@ "total_num_kv_heads": 0, "num_experts": 0, "is_deepseek_mla": false, - "is_multimodal_model": true, + "supports_multimodal": true, "dtype": "torch.float32" }, "Zyphra/Zamba2-7B-instruct": { @@ -64,7 +64,7 @@ "total_num_kv_heads": 32, "num_experts": 0, "is_deepseek_mla": false, - "is_multimodal_model": false, + "supports_multimodal": false, "dtype": "torch.bfloat16" }, "mosaicml/mpt-7b": { @@ -81,7 +81,7 @@ "total_num_kv_heads": 32, "num_experts": 0, "is_deepseek_mla": false, - "is_multimodal_model": false, + "supports_multimodal": false, "dtype": "torch.bfloat16" }, 
"databricks/dbrx-instruct": { @@ -98,7 +98,7 @@ "total_num_kv_heads": 8, "num_experts": 0, "is_deepseek_mla": false, - "is_multimodal_model": false, + "supports_multimodal": false, "dtype": "torch.bfloat16" }, "tiiuae/falcon-7b": { @@ -115,7 +115,7 @@ "total_num_kv_heads": 1, "num_experts": 0, "is_deepseek_mla": false, - "is_multimodal_model": false, + "supports_multimodal": false, "dtype": "torch.bfloat16" }, "tiiuae/falcon-40b": { @@ -132,7 +132,7 @@ "total_num_kv_heads": 8, "num_experts": 0, "is_deepseek_mla": false, - "is_multimodal_model": false, + "supports_multimodal": false, "dtype": "torch.bfloat16" }, "luccafong/deepseek_mtp_main_random": { @@ -149,7 +149,7 @@ "total_num_kv_heads": 32, "num_experts": 72, "is_deepseek_mla": true, - "is_multimodal_model": false, + "supports_multimodal": false, "dtype": "torch.bfloat16" }, "luccafong/deepseek_mtp_draft_random": { @@ -166,7 +166,7 @@ "total_num_kv_heads": 32, "num_experts": 72, "is_deepseek_mla": true, - "is_multimodal_model": false, + "supports_multimodal": false, "dtype": "torch.bfloat16" }, "Qwen/Qwen3-Next-80B-A3B-Instruct": { @@ -183,7 +183,7 @@ "total_num_kv_heads": 2, "num_experts": 512, "is_deepseek_mla": false, - "is_multimodal_model": false, + "supports_multimodal": false, "dtype": "torch.bfloat16" }, "tiny-random/qwen3-next-moe": { @@ -200,7 +200,7 @@ "total_num_kv_heads": 8, "num_experts": 32, "is_deepseek_mla": false, - "is_multimodal_model": false, + "supports_multimodal": false, "dtype": "torch.bfloat16" }, "zai-org/GLM-4.5": { @@ -217,7 +217,7 @@ "total_num_kv_heads": 8, "num_experts": 160, "is_deepseek_mla": false, - "is_multimodal_model": false, + "supports_multimodal": false, "dtype": "torch.bfloat16" }, "baidu/ERNIE-4.5-21B-A3B-PT": { @@ -234,7 +234,7 @@ "total_num_kv_heads": 4, "num_experts": 64, "is_deepseek_mla": false, - "is_multimodal_model": false, + "supports_multimodal": false, "dtype": "torch.bfloat16" }, "lmsys/gpt-oss-20b-bf16": { @@ -251,7 +251,7 @@ "total_num_kv_heads": 8, 
"num_experts": 32, "is_deepseek_mla": false, - "is_multimodal_model": false, + "supports_multimodal": false, "dtype": "torch.bfloat16" }, "deepseek-ai/DeepSeek-V3.2-Exp": { @@ -268,7 +268,7 @@ "total_num_kv_heads": 128, "num_experts": 256, "is_deepseek_mla": true, - "is_multimodal_model": false, + "supports_multimodal": false, "dtype": "torch.bfloat16" }, "meta-llama/Llama-4-Scout-17B-16E-Instruct": { @@ -285,7 +285,7 @@ "total_num_kv_heads": 8, "num_experts": 16, "is_deepseek_mla": false, - "is_multimodal_model": true, + "supports_multimodal": true, "dtype": "torch.bfloat16" }, "nvidia/Llama-3_3-Nemotron-Super-49B-v1": { @@ -302,7 +302,7 @@ "total_num_kv_heads": 8, "num_experts": 0, "is_deepseek_mla": false, - "is_multimodal_model": false, + "supports_multimodal": false, "dtype": "torch.bfloat16" }, "XiaomiMiMo/MiMo-7B-RL": { @@ -319,7 +319,7 @@ "total_num_kv_heads": 8, "num_experts": 0, "is_deepseek_mla": false, - "is_multimodal_model": false, + "supports_multimodal": false, "dtype": "torch.bfloat16" }, "meituan-longcat/LongCat-Flash-Chat": { @@ -336,7 +336,7 @@ "total_num_kv_heads": 64, "num_experts": 512, "is_deepseek_mla": true, - "is_multimodal_model": false, + "supports_multimodal": false, "dtype": "torch.float32" } } diff --git a/tests/config/test_model_arch_config.py b/tests/config/test_model_arch_config.py index 365cc1104ccaf..90c550de0e3e5 100644 --- a/tests/config/test_model_arch_config.py +++ b/tests/config/test_model_arch_config.py @@ -64,8 +64,6 @@ def test_model_arch_config(): assert model_arch_config.total_num_kv_heads == expected["total_num_kv_heads"] assert model_arch_config.num_experts == expected["num_experts"] assert model_arch_config.is_deepseek_mla == expected["is_deepseek_mla"] - assert model_arch_config.is_multimodal_model == expected["is_multimodal_model"] - dtype = model_arch_config.torch_dtype assert str(dtype) == expected["dtype"] diff --git a/vllm/config/model_arch.py b/vllm/config/model_arch.py index a35288c2330bf..6d9e32a24c5c8 
100644 --- a/vllm/config/model_arch.py +++ b/vllm/config/model_arch.py @@ -53,9 +53,6 @@ class ModelArchitectureConfig: torch_dtype: torch.dtype | str | None """PyTorch data type for model weights (e.g., 'float16', 'bfloat16').""" - is_multimodal_model: bool - """Whether the model is a multimodal model.""" - is_deepseek_mla: bool """Whether the model is a DeepSeek MLA model.""" diff --git a/vllm/transformers_utils/model_arch_config_convertor.py b/vllm/transformers_utils/model_arch_config_convertor.py index d1e28cbe558bb..d453a2395e66c 100644 --- a/vllm/transformers_utils/model_arch_config_convertor.py +++ b/vllm/transformers_utils/model_arch_config_convertor.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, final +from typing import final import torch from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE @@ -17,17 +17,8 @@ from vllm.transformers_utils.config import ( get_hf_text_config, try_get_safetensors_metadata, ) -from vllm.utils.import_utils import LazyLoader from vllm.utils.torch_utils import common_broadcastable_dtype -if TYPE_CHECKING: - import vllm.model_executor.models.registry as me_models_registry -else: - # Use lazy loading to avoid circular import - me_models_registry = LazyLoader( - "model_executor", globals(), "vllm.model_executor.models.registry" - ) - logger = init_logger(__name__) @@ -248,12 +239,6 @@ class ModelArchConfigConvertorBase: derived_max_model_len = tmp_max_len return derived_max_model_len, max_len_key - def is_multimodal_model(self) -> bool: - return any( - multi_model_arch in self.hf_config.architectures - for multi_model_arch in me_models_registry._MULTIMODAL_MODELS - ) - def convert(self, model_id: str, revision: str | None) -> ModelArchitectureConfig: model_arch_config = ModelArchitectureConfig( architectures=getattr(self.hf_config, "architectures", []), @@ -268,7 +253,6 @@ class 
ModelArchConfigConvertorBase: num_experts=self.get_num_experts(self.hf_text_config), quantization_config=self.get_quantization_config(self.hf_config), torch_dtype=self.get_torch_dtype(self.hf_config, model_id, revision), - is_multimodal_model=self.is_multimodal_model(), is_deepseek_mla=self.is_deepseek_mla(), derived_max_model_len_and_key=self.derive_max_model_len_and_key(), ) From 5870d362be004f6ee1d2230c80f2895758701b7c Mon Sep 17 00:00:00 2001 From: Xingyu Liu Date: Tue, 9 Dec 2025 13:56:46 -0800 Subject: [PATCH 08/23] support falcon_mamba Signed-off-by: Xingyu Liu --- tests/config/model_arch_groundtruth.json | 57 ++++++++++++------- vllm/config/model.py | 5 -- .../model_arch_config_convertor.py | 1 + 3 files changed, 38 insertions(+), 25 deletions(-) diff --git a/tests/config/model_arch_groundtruth.json b/tests/config/model_arch_groundtruth.json index c3540ab5bdf03..3401198ad7d56 100644 --- a/tests/config/model_arch_groundtruth.json +++ b/tests/config/model_arch_groundtruth.json @@ -13,7 +13,7 @@ "total_num_kv_heads": 0, "num_experts": 0, "is_deepseek_mla": false, - "supports_multimodal": false, + "is_multimodal_model": false, "dtype": "torch.float32" }, "mistralai/Mamba-Codestral-7B-v0.1": { @@ -30,7 +30,7 @@ "total_num_kv_heads": 0, "num_experts": 0, "is_deepseek_mla": false, - "supports_multimodal": false, + "is_multimodal_model": false, "dtype": "torch.bfloat16" }, "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11": { @@ -47,9 +47,26 @@ "total_num_kv_heads": 0, "num_experts": 0, "is_deepseek_mla": false, - "supports_multimodal": true, + "is_multimodal_model": true, "dtype": "torch.float32" }, + "tiiuae/falcon-mamba-7b-instruct": { + "architectures": [ + "FalconMambaForCausalLM" + ], + "model_type": "falcon_mamba", + "text_model_type": "falcon_mamba", + "hidden_size": 4096, + "total_num_hidden_layers": 64, + "total_num_attention_heads": 0, + "head_size": 0, + "vocab_size": 65024, + "total_num_kv_heads": 0, + "num_experts": 0, + "is_deepseek_mla": 
false, + "is_multimodal_model": false, + "dtype": "torch.bfloat16" + }, "Zyphra/Zamba2-7B-instruct": { "architectures": [ "Zamba2ForCausalLM" @@ -64,7 +81,7 @@ "total_num_kv_heads": 32, "num_experts": 0, "is_deepseek_mla": false, - "supports_multimodal": false, + "is_multimodal_model": false, "dtype": "torch.bfloat16" }, "mosaicml/mpt-7b": { @@ -81,7 +98,7 @@ "total_num_kv_heads": 32, "num_experts": 0, "is_deepseek_mla": false, - "supports_multimodal": false, + "is_multimodal_model": false, "dtype": "torch.bfloat16" }, "databricks/dbrx-instruct": { @@ -98,7 +115,7 @@ "total_num_kv_heads": 8, "num_experts": 0, "is_deepseek_mla": false, - "supports_multimodal": false, + "is_multimodal_model": false, "dtype": "torch.bfloat16" }, "tiiuae/falcon-7b": { @@ -115,7 +132,7 @@ "total_num_kv_heads": 1, "num_experts": 0, "is_deepseek_mla": false, - "supports_multimodal": false, + "is_multimodal_model": false, "dtype": "torch.bfloat16" }, "tiiuae/falcon-40b": { @@ -132,7 +149,7 @@ "total_num_kv_heads": 8, "num_experts": 0, "is_deepseek_mla": false, - "supports_multimodal": false, + "is_multimodal_model": false, "dtype": "torch.bfloat16" }, "luccafong/deepseek_mtp_main_random": { @@ -149,7 +166,7 @@ "total_num_kv_heads": 32, "num_experts": 72, "is_deepseek_mla": true, - "supports_multimodal": false, + "is_multimodal_model": false, "dtype": "torch.bfloat16" }, "luccafong/deepseek_mtp_draft_random": { @@ -166,7 +183,7 @@ "total_num_kv_heads": 32, "num_experts": 72, "is_deepseek_mla": true, - "supports_multimodal": false, + "is_multimodal_model": false, "dtype": "torch.bfloat16" }, "Qwen/Qwen3-Next-80B-A3B-Instruct": { @@ -183,7 +200,7 @@ "total_num_kv_heads": 2, "num_experts": 512, "is_deepseek_mla": false, - "supports_multimodal": false, + "is_multimodal_model": false, "dtype": "torch.bfloat16" }, "tiny-random/qwen3-next-moe": { @@ -200,7 +217,7 @@ "total_num_kv_heads": 8, "num_experts": 32, "is_deepseek_mla": false, - "supports_multimodal": false, + "is_multimodal_model": false, 
"dtype": "torch.bfloat16" }, "zai-org/GLM-4.5": { @@ -217,7 +234,7 @@ "total_num_kv_heads": 8, "num_experts": 160, "is_deepseek_mla": false, - "supports_multimodal": false, + "is_multimodal_model": false, "dtype": "torch.bfloat16" }, "baidu/ERNIE-4.5-21B-A3B-PT": { @@ -234,7 +251,7 @@ "total_num_kv_heads": 4, "num_experts": 64, "is_deepseek_mla": false, - "supports_multimodal": false, + "is_multimodal_model": false, "dtype": "torch.bfloat16" }, "lmsys/gpt-oss-20b-bf16": { @@ -251,7 +268,7 @@ "total_num_kv_heads": 8, "num_experts": 32, "is_deepseek_mla": false, - "supports_multimodal": false, + "is_multimodal_model": false, "dtype": "torch.bfloat16" }, "deepseek-ai/DeepSeek-V3.2-Exp": { @@ -268,7 +285,7 @@ "total_num_kv_heads": 128, "num_experts": 256, "is_deepseek_mla": true, - "supports_multimodal": false, + "is_multimodal_model": false, "dtype": "torch.bfloat16" }, "meta-llama/Llama-4-Scout-17B-16E-Instruct": { @@ -285,7 +302,7 @@ "total_num_kv_heads": 8, "num_experts": 16, "is_deepseek_mla": false, - "supports_multimodal": true, + "is_multimodal_model": true, "dtype": "torch.bfloat16" }, "nvidia/Llama-3_3-Nemotron-Super-49B-v1": { @@ -302,7 +319,7 @@ "total_num_kv_heads": 8, "num_experts": 0, "is_deepseek_mla": false, - "supports_multimodal": false, + "is_multimodal_model": false, "dtype": "torch.bfloat16" }, "XiaomiMiMo/MiMo-7B-RL": { @@ -319,7 +336,7 @@ "total_num_kv_heads": 8, "num_experts": 0, "is_deepseek_mla": false, - "supports_multimodal": false, + "is_multimodal_model": false, "dtype": "torch.bfloat16" }, "meituan-longcat/LongCat-Flash-Chat": { @@ -336,7 +353,7 @@ "total_num_kv_heads": 64, "num_experts": 512, "is_deepseek_mla": true, - "supports_multimodal": false, + "is_multimodal_model": false, "dtype": "torch.float32" } } diff --git a/vllm/config/model.py b/vllm/config/model.py index 84e59dfc83d2e..370f5c9b11935 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -1188,15 +1188,10 @@ class ModelConfig: return self.hf_config.model_type in 
MM_PREFIX_LM_MODELS def get_head_size(self) -> int: - if self.is_attention_free: - return 0 return self.model_arch_config.head_size def get_total_num_kv_heads(self) -> int: """Returns the total number of KV heads.""" - if self.is_attention_free: - return 0 - return self.model_arch_config.total_num_kv_heads def get_num_kv_heads(self, parallel_config: ParallelConfig) -> int: diff --git a/vllm/transformers_utils/model_arch_config_convertor.py b/vllm/transformers_utils/model_arch_config_convertor.py index d453a2395e66c..40cf438f4a804 100644 --- a/vllm/transformers_utils/model_arch_config_convertor.py +++ b/vllm/transformers_utils/model_arch_config_convertor.py @@ -365,6 +365,7 @@ class LongCatFlashMTPModelArchConfigConvertor(ModelArchConfigConvertorBase): # hf_config.model_type -> convertor class MODEL_ARCH_CONFIG_CONVERTORS = { "mamba": MambaModelArchConfigConvertor, + "falcon_mamba": MambaModelArchConfigConvertor, "timm_wrapper": TerratorchModelArchConfigConvertor, "zamba2": Zamba2ModelArchConfigConvertor, "mpt": MPTModelArchConfigConvertor, From b78c44f1d883f060cdc798925bf2eea785810d83 Mon Sep 17 00:00:00 2001 From: Xingyu Liu Date: Tue, 9 Dec 2025 15:20:15 -0800 Subject: [PATCH 09/23] remove terratorch test Signed-off-by: Xingyu Liu --- tests/config/test_model_arch_config.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/config/test_model_arch_config.py b/tests/config/test_model_arch_config.py index 90c550de0e3e5..b024b5ebec83e 100644 --- a/tests/config/test_model_arch_config.py +++ b/tests/config/test_model_arch_config.py @@ -10,14 +10,17 @@ def test_model_arch_config(): trust_remote_code_models = [ "nvidia/Llama-3_3-Nemotron-Super-49B-v1", "XiaomiMiMo/MiMo-7B-RL", - # Not available online right now + # Excluded: Not available online right now # "FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1", "meituan-longcat/LongCat-Flash-Chat", ] models_to_test = [ "state-spaces/mamba-130m-hf", "mistralai/Mamba-Codestral-7B-v0.1", - 
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", + # Excluded: terratorch/torchgeo version mismatch in + # Async Engine, Inputs, Utils, Worker, Config Test (CPU) CI test environment + # (NonGeoDataset import error). + # "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", "Zyphra/Zamba2-7B-instruct", "mosaicml/mpt-7b", "databricks/dbrx-instruct", @@ -29,7 +32,7 @@ def test_model_arch_config(): "tiny-random/qwen3-next-moe", "zai-org/GLM-4.5", "baidu/ERNIE-4.5-21B-A3B-PT", - # Select some models using base convertor for testing + # Models using base convertor "lmsys/gpt-oss-20b-bf16", "deepseek-ai/DeepSeek-V3.2-Exp", "meta-llama/Llama-4-Scout-17B-16E-Instruct", @@ -67,7 +70,7 @@ def test_model_arch_config(): dtype = model_arch_config.torch_dtype assert str(dtype) == expected["dtype"] - # Test that model_config methods return expected values + # Ensure model_config methods return expected values assert model_config.architectures == expected["architectures"] assert model_config.get_vocab_size() == expected["vocab_size"] assert model_config.get_hidden_size() == expected["hidden_size"] From f72949b2888925163da0885e56a6ff779a5003b9 Mon Sep 17 00:00:00 2001 From: Xingyu Liu Date: Tue, 9 Dec 2025 17:04:58 -0800 Subject: [PATCH 10/23] architectures can be None Signed-off-by: Xingyu Liu --- vllm/config/model_arch.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/config/model_arch.py b/vllm/config/model_arch.py index 6d9e32a24c5c8..1bf72fe91f646 100644 --- a/vllm/config/model_arch.py +++ b/vllm/config/model_arch.py @@ -17,8 +17,9 @@ class ModelArchitectureConfig: Configuration for model architecture that required by vLLM runtime """ - architectures: list[str] - """List of model architecture class names (e.g., ['LlamaForCausalLM']).""" + architectures: list[str] | None + """List of model architecture class names (e.g., ['LlamaForCausalLM']). 
+ It can be None upon calling `vllm_config.with_hf_config(config.text_config)`""" model_type: str """Model type identifier (e.g., 'llama', 'gpt_oss').""" From aab35fc31c6a803a589d39a284afb057b39f8c0a Mon Sep 17 00:00:00 2001 From: Xingyu Liu Date: Tue, 9 Dec 2025 20:47:53 -0800 Subject: [PATCH 11/23] support medusa Signed-off-by: Xingyu Liu --- vllm/config/speculative.py | 3 +++ vllm/transformers_utils/model_arch_config_convertor.py | 9 +++++++++ 2 files changed, 12 insertions(+) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index bf533bf14e55c..ad4057de834fb 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -401,6 +401,9 @@ class SpeculativeConfig: model_type="eagle", ) self.draft_model_config.hf_config = eagle_config + self.draft_model_config.model_arch_config = ( + self.draft_model_config.get_model_arch_config() + ) if self.num_speculative_tokens is not None and hasattr( self.draft_model_config.hf_config, "num_lookahead_tokens" diff --git a/vllm/transformers_utils/model_arch_config_convertor.py b/vllm/transformers_utils/model_arch_config_convertor.py index 40cf438f4a804..ed6ba0adb5e20 100644 --- a/vllm/transformers_utils/model_arch_config_convertor.py +++ b/vllm/transformers_utils/model_arch_config_convertor.py @@ -276,6 +276,14 @@ class TerratorchModelArchConfigConvertor(ModelArchConfigConvertorBase): return 0 +class MedusaModelArchConfigConvertor(ModelArchConfigConvertorBase): + def get_head_size(self) -> int: + return 0 + + def get_total_num_kv_heads(self) -> int: + return 0 + + class Zamba2ModelArchConfigConvertor(ModelArchConfigConvertorBase): def get_head_size(self) -> int: return getattr(self.hf_text_config, "attention_head_dim", 0) @@ -367,6 +375,7 @@ MODEL_ARCH_CONFIG_CONVERTORS = { "mamba": MambaModelArchConfigConvertor, "falcon_mamba": MambaModelArchConfigConvertor, "timm_wrapper": TerratorchModelArchConfigConvertor, + "medusa": MedusaModelArchConfigConvertor, "zamba2": Zamba2ModelArchConfigConvertor, 
"mpt": MPTModelArchConfigConvertor, "dbrx": DbrxModelArchConfigConvertor, From 65c6d2565d2256c3e4e53c07526da53699f6c8d2 Mon Sep 17 00:00:00 2001 From: Xingyu Liu Date: Tue, 9 Dec 2025 23:17:36 -0800 Subject: [PATCH 12/23] speculative tests Signed-off-by: Xingyu Liu --- ....json => base_model_arch_groundtruth.json} | 0 .../config/draft_model_arch_groundtruth.json | 87 +++++++++++++++++++ tests/config/test_model_arch_config.py | 74 +++++++++++++++- vllm/config/model.py | 2 +- .../model_arch_config_convertor.py | 5 +- 5 files changed, 161 insertions(+), 7 deletions(-) rename tests/config/{model_arch_groundtruth.json => base_model_arch_groundtruth.json} (100%) create mode 100644 tests/config/draft_model_arch_groundtruth.json diff --git a/tests/config/model_arch_groundtruth.json b/tests/config/base_model_arch_groundtruth.json similarity index 100% rename from tests/config/model_arch_groundtruth.json rename to tests/config/base_model_arch_groundtruth.json diff --git a/tests/config/draft_model_arch_groundtruth.json b/tests/config/draft_model_arch_groundtruth.json new file mode 100644 index 0000000000000..dfe6f3d39e93b --- /dev/null +++ b/tests/config/draft_model_arch_groundtruth.json @@ -0,0 +1,87 @@ +{ + "abhigoyal/vllm-medusa-llama-68m-random": { + "architectures": [ + "MedusaModel" + ], + "model_type": "medusa", + "text_model_type": "medusa", + "hidden_size": 768, + "total_num_hidden_layers": 1, + "total_num_attention_heads": 0, + "head_size": "Error: integer division or modulo by zero", + "vocab_size": 32000, + "total_num_kv_heads": 0, + "num_experts": 0, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "torch.float32" + }, + "luccafong/deepseek_mtp_draft_random": { + "architectures": [ + "DeepSeekMTPModel" + ], + "model_type": "deepseek_mtp", + "text_model_type": "deepseek_mtp", + "hidden_size": 2560, + "total_num_hidden_layers": 1, + "total_num_attention_heads": 32, + "head_size": 576, + "vocab_size": 129280, + "total_num_kv_heads": 32, + 
"num_experts": 72, + "is_deepseek_mla": true, + "is_multimodal_model": false, + "dtype": "torch.bfloat16" + }, + "eagle618/eagle-deepseek-v3-random": { + "architectures": [ + "EagleDeepSeekMTPModel" + ], + "model_type": "eagle", + "text_model_type": "deepseek_mtp", + "hidden_size": 2560, + "total_num_hidden_layers": 1, + "total_num_attention_heads": 32, + "head_size": 576, + "vocab_size": 129280, + "total_num_kv_heads": 32, + "num_experts": 72, + "is_deepseek_mla": true, + "is_multimodal_model": false, + "dtype": "bfloat16" + }, + "yuhuili/EAGLE-LLaMA3-Instruct-8B": { + "architectures": [ + "EagleLlamaForCausalLM" + ], + "model_type": "eagle", + "text_model_type": "llama", + "hidden_size": 4096, + "total_num_hidden_layers": 1, + "total_num_attention_heads": 32, + "head_size": 128, + "vocab_size": 128256, + "total_num_kv_heads": 8, + "num_experts": 0, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "float16" + }, + "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B": { + "architectures": [ + "Eagle3LlamaForCausalLM" + ], + "model_type": "eagle", + "text_model_type": "llama", + "hidden_size": 4096, + "total_num_hidden_layers": 1, + "total_num_attention_heads": 32, + "head_size": 128, + "vocab_size": 128256, + "total_num_kv_heads": 8, + "num_experts": 0, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "float16" + } +} diff --git a/tests/config/test_model_arch_config.py b/tests/config/test_model_arch_config.py index b024b5ebec83e..6419461b930ce 100644 --- a/tests/config/test_model_arch_config.py +++ b/tests/config/test_model_arch_config.py @@ -3,10 +3,10 @@ import json from pathlib import Path -from vllm.config import ModelConfig +from vllm.config import ModelConfig, SpeculativeConfig, ParallelConfig -def test_model_arch_config(): +def test_basic(): trust_remote_code_models = [ "nvidia/Llama-3_3-Nemotron-Super-49B-v1", "XiaomiMiMo/MiMo-7B-RL", @@ -38,7 +38,7 @@ def test_model_arch_config(): "meta-llama/Llama-4-Scout-17B-16E-Instruct", ] 
+ trust_remote_code_models - groundtruth_path = Path(__file__).parent / "model_arch_groundtruth.json" + groundtruth_path = Path(__file__).parent / "base_model_arch_groundtruth.json" with open(groundtruth_path) as f: model_arch_groundtruth = json.load(f) @@ -81,3 +81,71 @@ def test_model_arch_config(): model_config.get_total_num_hidden_layers() == expected["total_num_hidden_layers"] ) + + +def test_draft_models(): + speculative_models = [ + ("JackFram/llama-68m", "abhigoyal/vllm-medusa-llama-68m-random", False), + ("luccafong/deepseek_mtp_main_random", "luccafong/deepseek_mtp_draft_random", True), + ("eagle618/deepseek-v3-random", "eagle618/eagle-deepseek-v3-random", True), + ("meta-llama/Meta-Llama-3-8B-Instruct", "yuhuili/EAGLE-LLaMA3-Instruct-8B", True), + ("meta-llama/Llama-3.1-8B-Instruct", "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", True), + ] + + groundtruth_path = Path(__file__).parent / "draft_model_arch_groundtruth.json" + with open(groundtruth_path) as f: + model_arch_groundtruth = json.load(f) + + for target_model, draft_model, trust_remote_code in speculative_models: + print(f"testing {target_model=} {draft_model=}") + target_model_config = ModelConfig( + target_model, trust_remote_code=trust_remote_code + ) + speculative_config = { + "model": draft_model, + "num_speculative_tokens": 1, + "target_model_config": target_model_config, + "target_parallel_config": ParallelConfig(), + } + + speculative_config = SpeculativeConfig(**speculative_config) + model_config = speculative_config.draft_model_config + + model_arch_config = model_config.model_arch_config + expected = model_arch_groundtruth[draft_model] + assert model_arch_config.architectures == expected["architectures"] + assert model_arch_config.model_type == expected["model_type"] + assert model_arch_config.text_model_type == expected["text_model_type"] + assert model_arch_config.hidden_size == expected["hidden_size"] + assert ( + model_arch_config.total_num_hidden_layers + == 
expected["total_num_hidden_layers"] + ) + assert ( + model_arch_config.total_num_attention_heads + == expected["total_num_attention_heads"] + ) + + assert model_arch_config.vocab_size == expected["vocab_size"] + assert model_arch_config.total_num_kv_heads == expected["total_num_kv_heads"] + assert model_arch_config.num_experts == expected["num_experts"] + assert model_arch_config.is_deepseek_mla == expected["is_deepseek_mla"] + dtype = model_arch_config.torch_dtype + assert str(dtype) == expected["dtype"] + + # Ensure model_config methods return expected values + assert model_config.architectures == expected["architectures"] + assert model_config.get_vocab_size() == expected["vocab_size"] + assert model_config.get_hidden_size() == expected["hidden_size"] + assert model_config.get_total_num_kv_heads() == expected["total_num_kv_heads"] + assert model_config.get_num_experts() == expected["num_experts"] + assert ( + model_config.get_total_num_hidden_layers() + == expected["total_num_hidden_layers"] + ) + + if isinstance(expected["head_size"], int): + # Before model_arch_config is introduced, get_head_size() for medusa + # model config will throw out `integer division or modulo by zero` error. 
+ assert model_arch_config.head_size == expected["head_size"] + assert model_config.get_head_size() == expected["head_size"] diff --git a/vllm/config/model.py b/vllm/config/model.py index 370f5c9b11935..5ed08bdbe949f 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -716,7 +716,7 @@ class ModelConfig: convertor_cls = MODEL_ARCH_CONFIG_CONVERTORS.get( self.hf_config.model_type, ModelArchConfigConvertorBase ) - convertor = convertor_cls(self.hf_config) + convertor = convertor_cls(self.hf_config, self.hf_text_config) return convertor.convert(self.model, self.revision) @field_validator("tokenizer_mode", mode="after") diff --git a/vllm/transformers_utils/model_arch_config_convertor.py b/vllm/transformers_utils/model_arch_config_convertor.py index ed6ba0adb5e20..d785dce3d32e5 100644 --- a/vllm/transformers_utils/model_arch_config_convertor.py +++ b/vllm/transformers_utils/model_arch_config_convertor.py @@ -14,7 +14,6 @@ from vllm.config.model_arch import ( from vllm.config.utils import getattr_iter from vllm.logger import init_logger from vllm.transformers_utils.config import ( - get_hf_text_config, try_get_safetensors_metadata, ) from vllm.utils.torch_utils import common_broadcastable_dtype @@ -23,9 +22,9 @@ logger = init_logger(__name__) class ModelArchConfigConvertorBase: - def __init__(self, hf_config: PretrainedConfig): + def __init__(self, hf_config: PretrainedConfig, hf_text_config: PretrainedConfig): self.hf_config = hf_config - self.hf_text_config = get_hf_text_config(hf_config) + self.hf_text_config = hf_text_config def get_num_hidden_layers(self) -> int: return getattr(self.hf_text_config, "num_hidden_layers", 0) From 0cd72dc43897405c3e2005f4a9e62376817e2bd3 Mon Sep 17 00:00:00 2001 From: Xingyu Liu Date: Tue, 9 Dec 2025 23:18:29 -0800 Subject: [PATCH 13/23] speculative tests Signed-off-by: Xingyu Liu --- tests/config/test_model_arch_config.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git 
a/tests/config/test_model_arch_config.py b/tests/config/test_model_arch_config.py index 6419461b930ce..d7f12e1d5a6f8 100644 --- a/tests/config/test_model_arch_config.py +++ b/tests/config/test_model_arch_config.py @@ -3,7 +3,7 @@ import json from pathlib import Path -from vllm.config import ModelConfig, SpeculativeConfig, ParallelConfig +from vllm.config import ModelConfig, ParallelConfig, SpeculativeConfig def test_basic(): @@ -86,10 +86,22 @@ def test_basic(): def test_draft_models(): speculative_models = [ ("JackFram/llama-68m", "abhigoyal/vllm-medusa-llama-68m-random", False), - ("luccafong/deepseek_mtp_main_random", "luccafong/deepseek_mtp_draft_random", True), + ( + "luccafong/deepseek_mtp_main_random", + "luccafong/deepseek_mtp_draft_random", + True, + ), ("eagle618/deepseek-v3-random", "eagle618/eagle-deepseek-v3-random", True), - ("meta-llama/Meta-Llama-3-8B-Instruct", "yuhuili/EAGLE-LLaMA3-Instruct-8B", True), - ("meta-llama/Llama-3.1-8B-Instruct", "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", True), + ( + "meta-llama/Meta-Llama-3-8B-Instruct", + "yuhuili/EAGLE-LLaMA3-Instruct-8B", + True, + ), + ( + "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", + True, + ), ] groundtruth_path = Path(__file__).parent / "draft_model_arch_groundtruth.json" @@ -108,7 +120,7 @@ def test_draft_models(): "target_parallel_config": ParallelConfig(), } - speculative_config = SpeculativeConfig(**speculative_config) + speculative_config = SpeculativeConfig(**speculative_config) model_config = speculative_config.draft_model_config model_arch_config = model_config.model_arch_config @@ -145,7 +157,7 @@ def test_draft_models(): ) if isinstance(expected["head_size"], int): - # Before model_arch_config is introduced, get_head_size() for medusa + # Before model_arch_config is introduced, get_head_size() for medusa # model config will throw out `integer division or modulo by zero` error. 
assert model_arch_config.head_size == expected["head_size"] assert model_config.get_head_size() == expected["head_size"] From 5401f6529d3470db18291f9ef778993ffcc1cbd9 Mon Sep 17 00:00:00 2001 From: Xingyu Liu Date: Tue, 9 Dec 2025 23:25:02 -0800 Subject: [PATCH 14/23] refactor tests Signed-off-by: Xingyu Liu --- tests/config/test_model_arch_config.py | 259 ++++++++++++------------- 1 file changed, 120 insertions(+), 139 deletions(-) diff --git a/tests/config/test_model_arch_config.py b/tests/config/test_model_arch_config.py index d7f12e1d5a6f8..13f409236822f 100644 --- a/tests/config/test_model_arch_config.py +++ b/tests/config/test_model_arch_config.py @@ -1,163 +1,144 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for ModelArchitectureConfig and its integration with ModelConfig.""" + import json from pathlib import Path +import pytest + from vllm.config import ModelConfig, ParallelConfig, SpeculativeConfig +BASE_TRUST_REMOTE_CODE_MODELS = { + "nvidia/Llama-3_3-Nemotron-Super-49B-v1", + "XiaomiMiMo/MiMo-7B-RL", + # Excluded: Not available online right now + # "FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1", + "meituan-longcat/LongCat-Flash-Chat", +} -def test_basic(): - trust_remote_code_models = [ - "nvidia/Llama-3_3-Nemotron-Super-49B-v1", - "XiaomiMiMo/MiMo-7B-RL", - # Excluded: Not available online right now - # "FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1", - "meituan-longcat/LongCat-Flash-Chat", - ] - models_to_test = [ - "state-spaces/mamba-130m-hf", - "mistralai/Mamba-Codestral-7B-v0.1", - # Excluded: terratorch/torchgeo version mismatch in - # Async Engine, Inputs, Utils, Worker, Config Test (CPU) CI test environment - # (NonGeoDataset import error). 
- # "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", - "Zyphra/Zamba2-7B-instruct", - "mosaicml/mpt-7b", - "databricks/dbrx-instruct", - "tiiuae/falcon-7b", - "tiiuae/falcon-40b", - "luccafong/deepseek_mtp_main_random", - "luccafong/deepseek_mtp_draft_random", - "Qwen/Qwen3-Next-80B-A3B-Instruct", - "tiny-random/qwen3-next-moe", - "zai-org/GLM-4.5", - "baidu/ERNIE-4.5-21B-A3B-PT", - # Models using base convertor - "lmsys/gpt-oss-20b-bf16", - "deepseek-ai/DeepSeek-V3.2-Exp", - "meta-llama/Llama-4-Scout-17B-16E-Instruct", - ] + trust_remote_code_models +BASE_MODELS_TO_TEST = [ + "state-spaces/mamba-130m-hf", + "mistralai/Mamba-Codestral-7B-v0.1", + # Excluded: terratorch/torchgeo version mismatch in CPU CI environment + # (NonGeoDataset import error). Tested in model initialization tests. + # "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", + "Zyphra/Zamba2-7B-instruct", + "mosaicml/mpt-7b", + "databricks/dbrx-instruct", + "tiiuae/falcon-7b", + "tiiuae/falcon-40b", + "luccafong/deepseek_mtp_main_random", + "Qwen/Qwen3-Next-80B-A3B-Instruct", + "tiny-random/qwen3-next-moe", + "zai-org/GLM-4.5", + "baidu/ERNIE-4.5-21B-A3B-PT", + # Models using base convertor + "lmsys/gpt-oss-20b-bf16", + "deepseek-ai/DeepSeek-V3.2-Exp", + "meta-llama/Llama-4-Scout-17B-16E-Instruct", +] + list(BASE_TRUST_REMOTE_CODE_MODELS) - groundtruth_path = Path(__file__).parent / "base_model_arch_groundtruth.json" +# (target_model, draft_model, trust_remote_code) +SPECULATIVE_MODELS = [ + ("JackFram/llama-68m", "abhigoyal/vllm-medusa-llama-68m-random", False), + ("luccafong/deepseek_mtp_main_random", "luccafong/deepseek_mtp_draft_random", True), + ("eagle618/deepseek-v3-random", "eagle618/eagle-deepseek-v3-random", True), + ("meta-llama/Meta-Llama-3-8B-Instruct", "yuhuili/EAGLE-LLaMA3-Instruct-8B", True), + ("meta-llama/Llama-3.1-8B-Instruct", "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", True), +] + + +def _load_groundtruth(filename: str) -> dict: + """Load groundtruth JSON from the 
test directory.""" + groundtruth_path = Path(__file__).parent / filename with open(groundtruth_path) as f: - model_arch_groundtruth = json.load(f) + return json.load(f) - for model in models_to_test: - print(f"testing {model=}") - model_config = ModelConfig( - model, trust_remote_code=model in trust_remote_code_models - ) - model_arch_config = model_config.model_arch_config - expected = model_arch_groundtruth[model] - assert model_arch_config.architectures == expected["architectures"] - assert model_arch_config.model_type == expected["model_type"] - assert model_arch_config.text_model_type == expected["text_model_type"] - assert model_arch_config.hidden_size == expected["hidden_size"] - assert ( - model_arch_config.total_num_hidden_layers - == expected["total_num_hidden_layers"] - ) - assert ( - model_arch_config.total_num_attention_heads - == expected["total_num_attention_heads"] - ) +def _assert_model_arch_config( + model_arch_config, expected: dict, check_head_size: bool = True +): + """Assert model_arch_config matches expected values.""" + assert model_arch_config.architectures == expected["architectures"] + assert model_arch_config.model_type == expected["model_type"] + assert model_arch_config.text_model_type == expected["text_model_type"] + assert model_arch_config.hidden_size == expected["hidden_size"] + assert ( + model_arch_config.total_num_hidden_layers == expected["total_num_hidden_layers"] + ) + assert ( + model_arch_config.total_num_attention_heads + == expected["total_num_attention_heads"] + ) + assert model_arch_config.vocab_size == expected["vocab_size"] + assert model_arch_config.total_num_kv_heads == expected["total_num_kv_heads"] + assert model_arch_config.num_experts == expected["num_experts"] + assert model_arch_config.is_deepseek_mla == expected["is_deepseek_mla"] + assert str(model_arch_config.torch_dtype) == expected["dtype"] + + if check_head_size: assert model_arch_config.head_size == expected["head_size"] - assert 
model_arch_config.vocab_size == expected["vocab_size"] - assert model_arch_config.total_num_kv_heads == expected["total_num_kv_heads"] - assert model_arch_config.num_experts == expected["num_experts"] - assert model_arch_config.is_deepseek_mla == expected["is_deepseek_mla"] - dtype = model_arch_config.torch_dtype - assert str(dtype) == expected["dtype"] - # Ensure model_config methods return expected values - assert model_config.architectures == expected["architectures"] - assert model_config.get_vocab_size() == expected["vocab_size"] - assert model_config.get_hidden_size() == expected["hidden_size"] + +def _assert_model_config_methods( + model_config, expected: dict, check_head_size: bool = True +): + """Assert model_config methods return expected values.""" + assert model_config.architectures == expected["architectures"] + assert model_config.get_vocab_size() == expected["vocab_size"] + assert model_config.get_hidden_size() == expected["hidden_size"] + assert model_config.get_total_num_kv_heads() == expected["total_num_kv_heads"] + assert model_config.get_num_experts() == expected["num_experts"] + assert ( + model_config.get_total_num_hidden_layers() + == expected["total_num_hidden_layers"] + ) + + if check_head_size: assert model_config.get_head_size() == expected["head_size"] - assert model_config.get_total_num_kv_heads() == expected["total_num_kv_heads"] - assert model_config.get_num_experts() == expected["num_experts"] - assert ( - model_config.get_total_num_hidden_layers() - == expected["total_num_hidden_layers"] - ) -def test_draft_models(): - speculative_models = [ - ("JackFram/llama-68m", "abhigoyal/vllm-medusa-llama-68m-random", False), - ( - "luccafong/deepseek_mtp_main_random", - "luccafong/deepseek_mtp_draft_random", - True, - ), - ("eagle618/deepseek-v3-random", "eagle618/eagle-deepseek-v3-random", True), - ( - "meta-llama/Meta-Llama-3-8B-Instruct", - "yuhuili/EAGLE-LLaMA3-Instruct-8B", - True, - ), - ( - "meta-llama/Llama-3.1-8B-Instruct", - 
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", - True, - ), - ] +@pytest.mark.parametrize("model", BASE_MODELS_TO_TEST) +def test_base_model_arch_config(model: str): + """Test model architecture config for base models.""" + groundtruth = _load_groundtruth("base_model_arch_groundtruth.json") + expected = groundtruth[model] - groundtruth_path = Path(__file__).parent / "draft_model_arch_groundtruth.json" - with open(groundtruth_path) as f: - model_arch_groundtruth = json.load(f) + model_config = ModelConfig( + model, trust_remote_code=model in BASE_TRUST_REMOTE_CODE_MODELS + ) - for target_model, draft_model, trust_remote_code in speculative_models: - print(f"testing {target_model=} {draft_model=}") - target_model_config = ModelConfig( - target_model, trust_remote_code=trust_remote_code - ) - speculative_config = { - "model": draft_model, - "num_speculative_tokens": 1, - "target_model_config": target_model_config, - "target_parallel_config": ParallelConfig(), - } + _assert_model_arch_config(model_config.model_arch_config, expected) + _assert_model_config_methods(model_config, expected) - speculative_config = SpeculativeConfig(**speculative_config) - model_config = speculative_config.draft_model_config - model_arch_config = model_config.model_arch_config - expected = model_arch_groundtruth[draft_model] - assert model_arch_config.architectures == expected["architectures"] - assert model_arch_config.model_type == expected["model_type"] - assert model_arch_config.text_model_type == expected["text_model_type"] - assert model_arch_config.hidden_size == expected["hidden_size"] - assert ( - model_arch_config.total_num_hidden_layers - == expected["total_num_hidden_layers"] - ) - assert ( - model_arch_config.total_num_attention_heads - == expected["total_num_attention_heads"] - ) +@pytest.mark.parametrize( + "target_model,draft_model,trust_remote_code", SPECULATIVE_MODELS +) +def test_draft_model_arch_config( + target_model: str, draft_model: str, trust_remote_code: bool +): + """Test 
model architecture config for draft/speculative models.""" + groundtruth = _load_groundtruth("draft_model_arch_groundtruth.json") + expected = groundtruth[draft_model] - assert model_arch_config.vocab_size == expected["vocab_size"] - assert model_arch_config.total_num_kv_heads == expected["total_num_kv_heads"] - assert model_arch_config.num_experts == expected["num_experts"] - assert model_arch_config.is_deepseek_mla == expected["is_deepseek_mla"] - dtype = model_arch_config.torch_dtype - assert str(dtype) == expected["dtype"] + target_model_config = ModelConfig(target_model, trust_remote_code=trust_remote_code) + speculative_config = SpeculativeConfig( + model=draft_model, + num_speculative_tokens=1, + target_model_config=target_model_config, + target_parallel_config=ParallelConfig(), + ) + model_config = speculative_config.draft_model_config - # Ensure model_config methods return expected values - assert model_config.architectures == expected["architectures"] - assert model_config.get_vocab_size() == expected["vocab_size"] - assert model_config.get_hidden_size() == expected["hidden_size"] - assert model_config.get_total_num_kv_heads() == expected["total_num_kv_heads"] - assert model_config.get_num_experts() == expected["num_experts"] - assert ( - model_config.get_total_num_hidden_layers() - == expected["total_num_hidden_layers"] - ) + # For medusa models, head_size may cause division by zero before + # model_arch_config was introduced, so we conditionally check it + check_head_size = isinstance(expected["head_size"], int) - if isinstance(expected["head_size"], int): - # Before model_arch_config is introduced, get_head_size() for medusa - # model config will throw out `integer division or modulo by zero` error. 
- assert model_arch_config.head_size == expected["head_size"] - assert model_config.get_head_size() == expected["head_size"] + _assert_model_arch_config( + model_config.model_arch_config, expected, check_head_size=check_head_size + ) + _assert_model_config_methods( + model_config, expected, check_head_size=check_head_size + ) From 0f85cfdc659ee120e51741bd1d69775ca0dacafa Mon Sep 17 00:00:00 2001 From: Xingyu Liu Date: Tue, 16 Dec 2025 11:58:06 -0800 Subject: [PATCH 15/23] update with main Signed-off-by: Xingyu Liu --- .../model_arch_config_convertor.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/vllm/transformers_utils/model_arch_config_convertor.py b/vllm/transformers_utils/model_arch_config_convertor.py index d785dce3d32e5..705f5207b3687 100644 --- a/vllm/transformers_utils/model_arch_config_convertor.py +++ b/vllm/transformers_utils/model_arch_config_convertor.py @@ -26,6 +26,9 @@ class ModelArchConfigConvertorBase: self.hf_config = hf_config self.hf_text_config = hf_text_config + def get_architectures(self) -> list[str]: + return getattr(self.hf_config, "architectures", []) + def get_num_hidden_layers(self) -> int: return getattr(self.hf_text_config, "num_hidden_layers", 0) @@ -240,7 +243,7 @@ class ModelArchConfigConvertorBase: def convert(self, model_id: str, revision: str | None) -> ModelArchitectureConfig: model_arch_config = ModelArchitectureConfig( - architectures=getattr(self.hf_config, "architectures", []), + architectures=self.get_architectures(), model_type=self.hf_config.model_type, text_model_type=getattr(self.hf_text_config, "model_type", None), hidden_size=self.get_hidden_size(), @@ -331,7 +334,15 @@ class NemotronNasModelArchConfigConvertor(ModelArchConfigConvertorBase): self.hf_text_config.num_attention_heads // block.attention.n_heads_in_group ) - raise RuntimeError("Couldn't determine number of kv heads") + raise RuntimeError( + "Could not determine the number of key-value attention heads " + "from model 
configuration. " + f"Architecture: {self.get_architectures()}. " + "This usually indicates an unsupported model architecture or " + "missing configuration. " + "Please check if your model is supported at: " + "https://docs.vllm.ai/en/latest/models/supported_models.html" + ) class DeepSeekMTPModelArchConfigConvertor(ModelArchConfigConvertorBase): From f5348c8174ee7fbad54fda10465c62a16563b5a9 Mon Sep 17 00:00:00 2001 From: Xingyu Liu Date: Tue, 16 Dec 2025 12:44:34 -0800 Subject: [PATCH 16/23] fix precommit Signed-off-by: Xingyu Liu --- vllm/config/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config/model.py b/vllm/config/model.py index 5f27cf0572866..63e7fbf38b3ce 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -608,7 +608,7 @@ class ModelConfig: ) convertor = convertor_cls(self.hf_config, self.hf_text_config) return convertor.convert(self.model, self.revision) - + @field_validator("tokenizer", "max_model_len", mode="wrap") @classmethod def _skip_none_validation(cls, value: Any, handler: Callable) -> Any: From 78d47494df0e6ca6e0e73f9c5e9e70c74ed2252b Mon Sep 17 00:00:00 2001 From: Xingyu Liu Date: Mon, 22 Dec 2025 22:23:35 -0800 Subject: [PATCH 17/23] sync with #30957 Signed-off-by: Xingyu Liu --- .../model_arch_config_convertor.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/vllm/transformers_utils/model_arch_config_convertor.py b/vllm/transformers_utils/model_arch_config_convertor.py index 705f5207b3687..3caf86bfa40d2 100644 --- a/vllm/transformers_utils/model_arch_config_convertor.py +++ b/vllm/transformers_utils/model_arch_config_convertor.py @@ -150,12 +150,18 @@ class ModelArchConfigConvertorBase: producer_name = quant_cfg.get("producer", {}).get("name") if producer_name == "modelopt": quant_algo = quant_cfg.get("quantization", {}).get("quant_algo") - if quant_algo == "FP8": - quant_cfg["quant_method"] = "modelopt" - elif quant_algo == "NVFP4": - 
quant_cfg["quant_method"] = "modelopt_fp4" - elif quant_algo is not None: - raise ValueError(f"Unknown ModelOpt quant algo: {quant_algo}") + if quant_algo is not None: + quant_algo_upper = str(quant_algo).upper() + if quant_algo_upper in { + "FP8", + "FP8_PER_CHANNEL_PER_TOKEN", + "FP8_PB_WO", + }: + quant_cfg["quant_method"] = "modelopt" + elif quant_algo_upper == "NVFP4": + quant_cfg["quant_method"] = "modelopt_fp4" + else: + raise ValueError(f"Unknown ModelOpt quant algo: {quant_algo}") if quant_cfg is not None: # Use the community standard 'quant_method' From 5bde69c2b9199fe9a84462632a217003f561957e Mon Sep 17 00:00:00 2001 From: Xingyu Liu Date: Mon, 22 Dec 2025 23:08:08 -0800 Subject: [PATCH 18/23] remove convertor dependency on model&revision, only torch_dtype is classmethod Signed-off-by: Xingyu Liu --- tests/config/test_model_arch_config.py | 14 +++++++--- tests/models/utils.py | 9 ++----- tests/v1/metrics/test_perf_metrics.py | 10 +++++++ vllm/config/model.py | 17 +++++++----- vllm/config/model_arch.py | 4 --- .../model_arch_config_convertor.py | 27 +++++++------------ 6 files changed, 44 insertions(+), 37 deletions(-) diff --git a/tests/config/test_model_arch_config.py b/tests/config/test_model_arch_config.py index 13f409236822f..5531963b05dcd 100644 --- a/tests/config/test_model_arch_config.py +++ b/tests/config/test_model_arch_config.py @@ -8,6 +8,9 @@ from pathlib import Path import pytest from vllm.config import ModelConfig, ParallelConfig, SpeculativeConfig +from vllm.transformers_utils.model_arch_config_convertor import ( + ModelArchConfigConvertorBase, +) BASE_TRUST_REMOTE_CODE_MODELS = { "nvidia/Llama-3_3-Nemotron-Super-49B-v1", @@ -57,9 +60,10 @@ def _load_groundtruth(filename: str) -> dict: def _assert_model_arch_config( - model_arch_config, expected: dict, check_head_size: bool = True + model_config, expected: dict, check_head_size: bool = True ): """Assert model_arch_config matches expected values.""" + model_arch_config = 
model_config.model_arch_config assert model_arch_config.architectures == expected["architectures"] assert model_arch_config.model_type == expected["model_type"] assert model_arch_config.text_model_type == expected["text_model_type"] @@ -75,7 +79,11 @@ def _assert_model_arch_config( assert model_arch_config.total_num_kv_heads == expected["total_num_kv_heads"] assert model_arch_config.num_experts == expected["num_experts"] assert model_arch_config.is_deepseek_mla == expected["is_deepseek_mla"] - assert str(model_arch_config.torch_dtype) == expected["dtype"] + + torch_dtype = ModelArchConfigConvertorBase.get_torch_dtype( + model_config.hf_config, model_config.model_id, revision=model_config.revision + ) + assert str(torch_dtype) == expected["dtype"] if check_head_size: assert model_arch_config.head_size == expected["head_size"] @@ -109,7 +117,7 @@ def test_base_model_arch_config(model: str): model, trust_remote_code=model in BASE_TRUST_REMOTE_CODE_MODELS ) - _assert_model_arch_config(model_config.model_arch_config, expected) + _assert_model_arch_config(model_config, expected) _assert_model_config_methods(model_config, expected) diff --git a/tests/models/utils.py b/tests/models/utils.py index 479c056d543ec..ed1625d8d591a 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -14,9 +14,6 @@ from vllm.config.model import ModelConfig, ModelDType, RunnerOption from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs from vllm.multimodal.processing import InputProcessingContext from vllm.tokenizers import cached_tokenizer_from_config -from vllm.transformers_utils.model_arch_config_convertor import ( - ModelArchConfigConvertorBase, -) from .. 
import ci_envs from .registry import HF_EXAMPLE_MODELS @@ -486,12 +483,10 @@ def dummy_hf_overrides( "num_kv_shared_layers": 1, } - class DummyConfig: - hf_text_config = text_config - + model_arch_config = ModelConfig.get_model_arch_config(hf_config, text_config) # Only set MoE related config when the model has MoE layers. # Otherwise all models detected as MoE by _get_transformers_backend_cls. - if ModelArchConfigConvertorBase.get_num_experts(text_config) > 0: + if model_arch_config.num_experts > 0: update_dict.update( { "num_experts": num_experts, diff --git a/tests/v1/metrics/test_perf_metrics.py b/tests/v1/metrics/test_perf_metrics.py index b6cda7bef3d41..e3846a7a3ef16 100644 --- a/tests/v1/metrics/test_perf_metrics.py +++ b/tests/v1/metrics/test_perf_metrics.py @@ -16,6 +16,10 @@ from transformers.models.qwen3.configuration_qwen3 import Qwen3Config from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig from vllm.config.model import ModelConfig, get_hf_text_config +from vllm.transformers_utils.model_arch_config_convertor import ( + MODEL_ARCH_CONFIG_CONVERTORS, + ModelArchConfigConvertorBase, +) from vllm.v1.metrics.perf import ( AttentionMetrics, BaseConfigParser, @@ -33,6 +37,12 @@ class MockModelConfig: def __init__(self, hf_config, dtype): self.hf_config = hf_config self.hf_text_config = get_hf_text_config(hf_config) + convertor_cls = MODEL_ARCH_CONFIG_CONVERTORS.get( + self.hf_config.model_type, ModelArchConfigConvertorBase + ) + self.model_arch_config = convertor_cls( + self.hf_config, self.hf_text_config + ).convert() self.dtype = dtype self.is_attention_free = False diff --git a/vllm/config/model.py b/vllm/config/model.py index ae80bc7b09b6d..2062377d34640 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -484,7 +484,9 @@ class ModelConfig: self.hf_image_processor_config = get_hf_image_processor_config( self.model, hf_token=self.hf_token, revision=self.revision ) - self.model_arch_config = 
self.get_model_arch_config() + self.model_arch_config = self.get_model_arch_config( + self.hf_config, self.hf_text_config + ) architectures = self.architectures registry = self.registry @@ -602,12 +604,15 @@ class ModelConfig: self._verify_cuda_graph() self._verify_bnb_config() - def get_model_arch_config(self) -> ModelArchitectureConfig: + @classmethod + def get_model_arch_config( + cls, hf_config, hf_text_config + ) -> ModelArchitectureConfig: convertor_cls = MODEL_ARCH_CONFIG_CONVERTORS.get( - self.hf_config.model_type, ModelArchConfigConvertorBase + hf_config.model_type, ModelArchConfigConvertorBase ) - convertor = convertor_cls(self.hf_config, self.hf_text_config) - return convertor.convert(self.model, self.revision) + convertor = convertor_cls(hf_config, hf_text_config) + return convertor.convert() @field_validator("tokenizer", "max_model_len", mode="wrap") @classmethod @@ -850,7 +855,7 @@ class ModelConfig: self.quantization = cast(me_quant.QuantizationMethods, self.quantization) # Parse quantization method from the HF model config, if available. 
- quant_cfg = ModelArchConfigConvertorBase.get_quantization_config(self.hf_config) + quant_cfg = self.model_arch_config.quantization_config if quant_cfg is not None: quant_method = quant_cfg["quant_method"] diff --git a/vllm/config/model_arch.py b/vllm/config/model_arch.py index 1bf72fe91f646..d55e2a3399b39 100644 --- a/vllm/config/model_arch.py +++ b/vllm/config/model_arch.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import Any -import torch from pydantic import ConfigDict from pydantic.dataclasses import dataclass @@ -51,9 +50,6 @@ class ModelArchitectureConfig: quantization_config: dict[str, Any] | None """Quantization configuration dictionary containing quantization parameters.""" - torch_dtype: torch.dtype | str | None - """PyTorch data type for model weights (e.g., 'float16', 'bfloat16').""" - is_deepseek_mla: bool """Whether the model is a DeepSeek MLA model.""" diff --git a/vllm/transformers_utils/model_arch_config_convertor.py b/vllm/transformers_utils/model_arch_config_convertor.py index 3caf86bfa40d2..1aa12345ec588 100644 --- a/vllm/transformers_utils/model_arch_config_convertor.py +++ b/vllm/transformers_utils/model_arch_config_convertor.py @@ -81,9 +81,7 @@ class ModelArchConfigConvertorBase: return self.hf_text_config.num_attention_heads - @final - @classmethod - def get_num_experts(cls, hf_text_config: PretrainedConfig) -> int: + def get_num_experts(self) -> int: """Returns the number of experts in the model.""" num_expert_names = [ "num_experts", # Jamba @@ -91,7 +89,7 @@ class ModelArchConfigConvertorBase: "n_routed_experts", # DeepSeek "num_local_experts", # Mixtral ] - num_experts = getattr_iter(hf_text_config, num_expert_names, 0) + num_experts = getattr_iter(self.hf_text_config, num_expert_names, 0) if isinstance(num_experts, list): # Ernie VL's remote code uses list[int]... # The values are always the same so we just take the first one. 
@@ -137,9 +135,7 @@ class ModelArchConfigConvertorBase: return config_dtype - @final - @classmethod - def _normalize_quantization_config(cls, config: PretrainedConfig): + def _normalize_quantization_config(self, config: PretrainedConfig): quant_cfg = getattr(config, "quantization_config", None) if quant_cfg is None: # compressed-tensors uses a "compression_config" key @@ -176,15 +172,13 @@ class ModelArchConfigConvertorBase: return quant_cfg - @final - @classmethod - def get_quantization_config(cls, hf_config: PretrainedConfig): - quant_cfg = cls._normalize_quantization_config(hf_config) + def get_quantization_config(self): + quant_cfg = self._normalize_quantization_config(self.hf_config) if quant_cfg is None and ( - text_config := getattr(hf_config, "text_config", None) + text_config := getattr(self.hf_config, "text_config", None) ): # Check the text config as well for multi-modal models. - quant_cfg = cls._normalize_quantization_config(text_config) + quant_cfg = self._normalize_quantization_config(text_config) return quant_cfg def is_deepseek_mla(self) -> bool: @@ -247,7 +241,7 @@ class ModelArchConfigConvertorBase: derived_max_model_len = tmp_max_len return derived_max_model_len, max_len_key - def convert(self, model_id: str, revision: str | None) -> ModelArchitectureConfig: + def convert(self) -> ModelArchitectureConfig: model_arch_config = ModelArchitectureConfig( architectures=self.get_architectures(), model_type=self.hf_config.model_type, @@ -258,9 +252,8 @@ class ModelArchConfigConvertorBase: head_size=self.get_head_size(), vocab_size=self.get_vocab_size(), total_num_kv_heads=self.get_total_num_kv_heads(), - num_experts=self.get_num_experts(self.hf_text_config), - quantization_config=self.get_quantization_config(self.hf_config), - torch_dtype=self.get_torch_dtype(self.hf_config, model_id, revision), + num_experts=self.get_num_experts(), + quantization_config=self.get_quantization_config(), is_deepseek_mla=self.is_deepseek_mla(), 
derived_max_model_len_and_key=self.derive_max_model_len_and_key(), ) From 1c3db5611a596b7e1e97a9ed6ba81bdedeedd2c1 Mon Sep 17 00:00:00 2001 From: Xingyu Liu Date: Mon, 22 Dec 2025 23:12:50 -0800 Subject: [PATCH 19/23] make get_model_arch_config not classmethod Signed-off-by: Xingyu Liu --- tests/models/utils.py | 6 +++++- vllm/config/model.py | 11 ++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/models/utils.py b/tests/models/utils.py index ed1625d8d591a..8e3a86f426d5d 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -483,7 +483,11 @@ def dummy_hf_overrides( "num_kv_shared_layers": 1, } - model_arch_config = ModelConfig.get_model_arch_config(hf_config, text_config) + class DummyConfig: + hf_config = hf_config + hf_text_config = text_config + + model_arch_config = ModelConfig.get_model_arch_config(DummyConfig) # Only set MoE related config when the model has MoE layers. # Otherwise all models detected as MoE by _get_transformers_backend_cls. 
if model_arch_config.num_experts > 0: diff --git a/vllm/config/model.py b/vllm/config/model.py index 2062377d34640..dfe191199f1a9 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -484,9 +484,7 @@ class ModelConfig: self.hf_image_processor_config = get_hf_image_processor_config( self.model, hf_token=self.hf_token, revision=self.revision ) - self.model_arch_config = self.get_model_arch_config( - self.hf_config, self.hf_text_config - ) + self.model_arch_config = self.get_model_arch_config() architectures = self.architectures registry = self.registry @@ -604,14 +602,13 @@ class ModelConfig: self._verify_cuda_graph() self._verify_bnb_config() - @classmethod def get_model_arch_config( - cls, hf_config, hf_text_config + self, ) -> ModelArchitectureConfig: convertor_cls = MODEL_ARCH_CONFIG_CONVERTORS.get( - hf_config.model_type, ModelArchConfigConvertorBase + self.hf_config.model_type, ModelArchConfigConvertorBase ) - convertor = convertor_cls(hf_config, hf_text_config) + convertor = convertor_cls(self.hf_config, self.hf_text_config) return convertor.convert() @field_validator("tokenizer", "max_model_len", mode="wrap") From 441d7355a28353d1c1897ceee7d432d3702df63e Mon Sep 17 00:00:00 2001 From: Xingyu Liu Date: Tue, 23 Dec 2025 15:33:10 -0800 Subject: [PATCH 20/23] fix tests Signed-off-by: Xingyu Liu --- tests/config/test_model_arch_config.py | 2 +- tests/models/utils.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/config/test_model_arch_config.py b/tests/config/test_model_arch_config.py index 5531963b05dcd..da699fed040de 100644 --- a/tests/config/test_model_arch_config.py +++ b/tests/config/test_model_arch_config.py @@ -81,7 +81,7 @@ def _assert_model_arch_config( assert model_arch_config.is_deepseek_mla == expected["is_deepseek_mla"] torch_dtype = ModelArchConfigConvertorBase.get_torch_dtype( - model_config.hf_config, model_config.model_id, revision=model_config.revision + model_config.hf_config, model_config.model, 
revision=model_config.revision ) assert str(torch_dtype) == expected["dtype"] diff --git a/tests/models/utils.py b/tests/models/utils.py index 8e3a86f426d5d..63dd07b827dbb 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -483,8 +483,10 @@ def dummy_hf_overrides( "num_kv_shared_layers": 1, } + _hf_config = hf_config + class DummyConfig: - hf_config = hf_config + hf_config = _hf_config hf_text_config = text_config model_arch_config = ModelConfig.get_model_arch_config(DummyConfig) From e1b6bfa8244b2e53ce23521c32983de453c38fd2 Mon Sep 17 00:00:00 2001 From: Xingyu Liu Date: Tue, 23 Dec 2025 16:02:40 -0800 Subject: [PATCH 21/23] sync with #29788 Signed-off-by: Xingyu Liu --- .../model_arch_config_convertor.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/transformers_utils/model_arch_config_convertor.py b/vllm/transformers_utils/model_arch_config_convertor.py index 1aa12345ec588..dc067a09419b7 100644 --- a/vllm/transformers_utils/model_arch_config_convertor.py +++ b/vllm/transformers_utils/model_arch_config_convertor.py @@ -74,12 +74,12 @@ class ModelArchConfigConvertorBase: # For ChatGLM: "multi_query_group_num", ] - for attr in attributes: - num_kv_heads = getattr(self.hf_text_config, attr, None) - if num_kv_heads is not None: - return num_kv_heads - - return self.hf_text_config.num_attention_heads + # For non-grouped-query attention models, the number of KV heads is + # equal to the number of attention heads. 
+ default_factory = lambda: self.hf_text_config.num_attention_heads + return getattr_iter( + self.hf_text_config, attributes, default_factory=default_factory + ) def get_num_experts(self) -> int: """Returns the number of experts in the model.""" From 0d143e4be06b94381205b2d8bd7dba3a0147c630 Mon Sep 17 00:00:00 2001 From: Xingyu Liu Date: Tue, 23 Dec 2025 16:11:36 -0800 Subject: [PATCH 22/23] precommit fix Signed-off-by: Xingyu Liu --- vllm/config/model.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/config/model.py b/vllm/config/model.py index cb75080ffa5d4..30b3d017e1d43 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -10,8 +10,6 @@ from typing import TYPE_CHECKING, Any, Literal, cast, get_args import torch from pydantic import ConfigDict, Field, field_validator, model_validator from pydantic.dataclasses import dataclass -from transformers.configuration_utils import ALLOWED_LAYER_TYPES -from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE import vllm.envs as envs from vllm.attention.backends.registry import AttentionBackendEnum @@ -21,7 +19,7 @@ from vllm.config.model_arch import ( from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig from vllm.config.pooler import PoolerConfig from vllm.config.scheduler import RunnerType -from vllm.config.utils import config +from vllm.config.utils import config, getattr_iter from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.transformers_utils.config import ( From 58eb6c43dde94a90ae5dd9f77369747147ac1497 Mon Sep 17 00:00:00 2001 From: Xingyu Liu Date: Wed, 24 Dec 2025 00:14:18 -0800 Subject: [PATCH 23/23] remove mosaicml/mpt-7b in tests Signed-off-by: Xingyu Liu --- tests/config/test_model_arch_config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/config/test_model_arch_config.py b/tests/config/test_model_arch_config.py index da699fed040de..f123c4336b6c0 100644 --- 
a/tests/config/test_model_arch_config.py +++ b/tests/config/test_model_arch_config.py @@ -27,7 +27,8 @@ BASE_MODELS_TO_TEST = [ # (NonGeoDataset import error). Tested in model initialization tests. # "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", "Zyphra/Zamba2-7B-instruct", - "mosaicml/mpt-7b", + # FIXME: mosaicml/mpt-7b has been deleted + # "mosaicml/mpt-7b", "databricks/dbrx-instruct", "tiiuae/falcon-7b", "tiiuae/falcon-40b",