From ec811683302f601e476a8df760549cd67fe0c389 Mon Sep 17 00:00:00 2001 From: Xingyu Liu Date: Fri, 5 Dec 2025 01:46:27 -0800 Subject: [PATCH] introduce model arch config Signed-off-by: Xingyu Liu --- tests/config/model_arch_groundtruth.json | 308 +++++++++++++++ tests/config/test_model_arch_config.py | 87 ++++ tests/test_config.py | 1 + vllm/config/model.py | 311 ++------------- vllm/config/model_arch.py | 63 +++ .../model_arch_config_convertor.py | 374 ++++++++++++++++++ 6 files changed, 874 insertions(+), 270 deletions(-) create mode 100644 tests/config/model_arch_groundtruth.json create mode 100644 tests/config/test_model_arch_config.py create mode 100644 vllm/config/model_arch.py create mode 100644 vllm/transformers_utils/model_arch_config_convertor.py diff --git a/tests/config/model_arch_groundtruth.json b/tests/config/model_arch_groundtruth.json new file mode 100644 index 0000000000000..f8fabf4bd9ef1 --- /dev/null +++ b/tests/config/model_arch_groundtruth.json @@ -0,0 +1,308 @@ +{ + "Zyphra/Zamba2-7B-instruct": { + "architectures": [ + "Zamba2ForCausalLM" + ], + "model_type": "zamba2", + "text_model_type": "zamba2", + "hidden_size": 3584, + "total_num_hidden_layers": 81, + "total_num_attention_heads": 32, + "head_size": 224, + "vocab_size": 32000, + "total_num_kv_heads": 32, + "num_experts": 0, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "mosaicml/mpt-7b": { + "architectures": [ + "MPTForCausalLM" + ], + "model_type": "mpt", + "text_model_type": "mpt", + "hidden_size": 4096, + "total_num_hidden_layers": 32, + "total_num_attention_heads": 32, + "head_size": 128, + "vocab_size": 50432, + "total_num_kv_heads": 32, + "num_experts": 0, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "databricks/dbrx-instruct": { + "architectures": [ + "DbrxForCausalLM" + ], + "model_type": "dbrx", + 
"text_model_type": "dbrx", + "hidden_size": 6144, + "total_num_hidden_layers": 40, + "total_num_attention_heads": 48, + "head_size": 128, + "vocab_size": 100352, + "total_num_kv_heads": 8, + "num_experts": 0, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "tiiuae/falcon-7b": { + "architectures": [ + "FalconForCausalLM" + ], + "model_type": "falcon", + "text_model_type": "falcon", + "hidden_size": 4544, + "total_num_hidden_layers": 32, + "total_num_attention_heads": 71, + "head_size": 64, + "vocab_size": 65024, + "total_num_kv_heads": 1, + "num_experts": 0, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "tiiuae/falcon-40b": { + "architectures": [ + "FalconForCausalLM" + ], + "model_type": "falcon", + "text_model_type": "falcon", + "hidden_size": 8192, + "total_num_hidden_layers": 60, + "total_num_attention_heads": 128, + "head_size": 64, + "vocab_size": 65024, + "total_num_kv_heads": 8, + "num_experts": 0, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "luccafong/deepseek_mtp_main_random": { + "architectures": [ + "DeepseekV3ForCausalLM" + ], + "model_type": "deepseek_v3", + "text_model_type": "deepseek_v3", + "hidden_size": 2560, + "total_num_hidden_layers": 5, + "total_num_attention_heads": 32, + "head_size": 576, + "vocab_size": 129280, + "total_num_kv_heads": 32, + "num_experts": 72, + "is_deepseek_mla": true, + "is_multimodal_model": false, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "luccafong/deepseek_mtp_draft_random": { + "architectures": [ + "DeepseekV3ForCausalLM" + ], + "model_type": "deepseek_v3", + "text_model_type": "deepseek_v3", + "hidden_size": 2560, + "total_num_hidden_layers": 10, + "total_num_attention_heads": 32, + "head_size": 576, + "vocab_size": 
129280, + "total_num_kv_heads": 32, + "num_experts": 72, + "is_deepseek_mla": true, + "is_multimodal_model": false, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "Qwen/Qwen3-Next-80B-A3B-Instruct": { + "architectures": [ + "Qwen3NextForCausalLM" + ], + "model_type": "qwen3_next", + "text_model_type": "qwen3_next", + "hidden_size": 2048, + "total_num_hidden_layers": 48, + "total_num_attention_heads": 16, + "head_size": 256, + "vocab_size": 151936, + "total_num_kv_heads": 2, + "num_experts": 512, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "tiny-random/qwen3-next-moe": { + "architectures": [ + "Qwen3NextForCausalLM" + ], + "model_type": "qwen3_next", + "text_model_type": "qwen3_next", + "hidden_size": 8, + "total_num_hidden_layers": 4, + "total_num_attention_heads": 16, + "head_size": 32, + "vocab_size": 151936, + "total_num_kv_heads": 8, + "num_experts": 32, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "zai-org/GLM-4.5": { + "architectures": [ + "Glm4MoeForCausalLM" + ], + "model_type": "glm4_moe", + "text_model_type": "glm4_moe", + "hidden_size": 5120, + "total_num_hidden_layers": 92, + "total_num_attention_heads": 96, + "head_size": 128, + "vocab_size": 151552, + "total_num_kv_heads": 8, + "num_experts": 160, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "baidu/ERNIE-4.5-21B-A3B-PT": { + "architectures": [ + "Ernie4_5_MoeForCausalLM" + ], + "model_type": "ernie4_5_moe", + "text_model_type": "ernie4_5_moe", + "hidden_size": 2560, + "total_num_hidden_layers": 28, + "total_num_attention_heads": 20, + "head_size": 128, + "vocab_size": 103424, + "total_num_kv_heads": 4, + "num_experts": 64, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": 
"torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "lmsys/gpt-oss-20b-bf16": { + "architectures": [ + "GptOssForCausalLM" + ], + "model_type": "gpt_oss", + "text_model_type": "gpt_oss", + "hidden_size": 2880, + "total_num_hidden_layers": 24, + "total_num_attention_heads": 64, + "head_size": 64, + "vocab_size": 201088, + "total_num_kv_heads": 8, + "num_experts": 32, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "deepseek-ai/DeepSeek-V3.2-Exp": { + "architectures": [ + "DeepseekV32ForCausalLM" + ], + "model_type": "deepseek_v32", + "text_model_type": "deepseek_v32", + "hidden_size": 7168, + "total_num_hidden_layers": 61, + "total_num_attention_heads": 128, + "head_size": 576, + "vocab_size": 129280, + "total_num_kv_heads": 128, + "num_experts": 256, + "is_deepseek_mla": true, + "is_multimodal_model": false, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "meta-llama/Llama-4-Scout-17B-16E-Instruct": { + "architectures": [ + "Llama4ForConditionalGeneration" + ], + "model_type": "llama4", + "text_model_type": "llama4_text", + "hidden_size": 5120, + "total_num_hidden_layers": 48, + "total_num_attention_heads": 40, + "head_size": 128, + "vocab_size": 202048, + "total_num_kv_heads": 8, + "num_experts": 16, + "is_deepseek_mla": false, + "is_multimodal_model": true, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "nvidia/Llama-3_3-Nemotron-Super-49B-v1": { + "architectures": [ + "DeciLMForCausalLM" + ], + "model_type": "nemotron-nas", + "text_model_type": "nemotron-nas", + "hidden_size": 8192, + "total_num_hidden_layers": 80, + "total_num_attention_heads": 64, + "head_size": 128, + "vocab_size": 128256, + "total_num_kv_heads": 8, + "num_experts": 0, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "XiaomiMiMo/MiMo-7B-RL": { + 
"architectures": [ + "MiMoForCausalLM" + ], + "model_type": "mimo", + "text_model_type": "mimo", + "hidden_size": 4096, + "total_num_hidden_layers": 36, + "total_num_attention_heads": 32, + "head_size": 128, + "vocab_size": 151680, + "total_num_kv_heads": 8, + "num_experts": 0, + "is_deepseek_mla": false, + "is_multimodal_model": false, + "dtype": "torch.bfloat16", + "dtype_original_type": "torch.dtype" + }, + "meituan-longcat/LongCat-Flash-Chat": { + "architectures": [ + "LongcatFlashForCausalLM" + ], + "model_type": "longcat_flash", + "text_model_type": "longcat_flash", + "hidden_size": 6144, + "total_num_hidden_layers": 28, + "total_num_attention_heads": 64, + "head_size": 576, + "vocab_size": 131072, + "total_num_kv_heads": 64, + "num_experts": 512, + "is_deepseek_mla": true, + "is_multimodal_model": false, + "dtype": "torch.float32", + "dtype_original_type": "torch.dtype" + } +} diff --git a/tests/config/test_model_arch_config.py b/tests/config/test_model_arch_config.py new file mode 100644 index 0000000000000..2900cd977efab --- /dev/null +++ b/tests/config/test_model_arch_config.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json +from pathlib import Path + +import torch + +from vllm.config import ModelConfig + + +def test_model_arch_config(): + trust_remote_code_models = [ + "nvidia/Llama-3_3-Nemotron-Super-49B-v1", + "XiaomiMiMo/MiMo-7B-RL", + # Not available online right now + # "FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1", + "meituan-longcat/LongCat-Flash-Chat", + ] + models_to_test = [ + "Zyphra/Zamba2-7B-instruct", + "mosaicml/mpt-7b", + "databricks/dbrx-instruct", + "tiiuae/falcon-7b", + "tiiuae/falcon-40b", + "luccafong/deepseek_mtp_main_random", + "luccafong/deepseek_mtp_draft_random", + "Qwen/Qwen3-Next-80B-A3B-Instruct", + "tiny-random/qwen3-next-moe", + "zai-org/GLM-4.5", + "baidu/ERNIE-4.5-21B-A3B-PT", + # Select some models using base convertor for 
testing + "lmsys/gpt-oss-20b-bf16", + "deepseek-ai/DeepSeek-V3.2-Exp", + "meta-llama/Llama-4-Scout-17B-16E-Instruct", + ] + trust_remote_code_models + + groundtruth_path = Path(__file__).parent / "model_arch_groundtruth.json" + with open(groundtruth_path) as f: + model_arch_groundtruth = json.load(f) + + for model in models_to_test: + print(f"testing {model=}") + model_config = ModelConfig( + model, trust_remote_code=model in trust_remote_code_models + ) + + model_arch_config = model_config.model_arch_config + expected = model_arch_groundtruth[model] + assert model_arch_config.architectures == expected["architectures"] + assert model_arch_config.model_type == expected["model_type"] + assert model_arch_config.text_model_type == expected["text_model_type"] + assert model_arch_config.hidden_size == expected["hidden_size"] + assert ( + model_arch_config.total_num_hidden_layers + == expected["total_num_hidden_layers"] + ) + assert ( + model_arch_config.total_num_attention_heads + == expected["total_num_attention_heads"] + ) + assert model_arch_config.head_size == expected["head_size"] + assert model_arch_config.vocab_size == expected["vocab_size"] + assert model_arch_config.total_num_kv_heads == expected["total_num_kv_heads"] + assert model_arch_config.num_experts == expected["num_experts"] + assert model_arch_config.is_deepseek_mla == expected["is_deepseek_mla"] + assert model_arch_config.is_multimodal_model == expected["is_multimodal_model"] + + dtype = model_arch_config.torch_dtype + assert str(dtype) == expected["dtype"] + if expected["dtype_original_type"] == "str": + assert isinstance(dtype, str) + elif expected["dtype_original_type"] == "torch.dtype": + assert isinstance(dtype, torch.dtype) + else: + raise ValueError(f"Unknown dtype_original_type: {expected['dtype']}") + + # Test that model_config methods return expected values + assert model_config.architectures == expected["architectures"] + assert model_config.get_vocab_size() == expected["vocab_size"] + 
assert model_config.get_hidden_size() == expected["hidden_size"] + assert model_config.get_head_size() == expected["head_size"] + assert model_config.get_total_num_kv_heads() == expected["total_num_kv_heads"] + assert model_config.get_num_experts() == expected["num_experts"] + assert ( + model_config.get_total_num_hidden_layers() + == expected["total_num_hidden_layers"] + ) diff --git a/tests/test_config.py b/tests/test_config.py index 77d3a7115978e..65807a17e9507 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + import logging import os from dataclasses import MISSING, Field, asdict, dataclass, field diff --git a/vllm/config/model.py b/vllm/config/model.py index 764bdf7000561..7f6313005737b 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -10,15 +10,17 @@ from typing import TYPE_CHECKING, Any, Literal, cast, get_args import torch from pydantic import ConfigDict, SkipValidation, field_validator, model_validator from pydantic.dataclasses import dataclass -from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from transformers.configuration_utils import ALLOWED_LAYER_TYPES import vllm.envs as envs from vllm.attention.backends.registry import AttentionBackendEnum +from vllm.config.model_arch import ( + ModelArchitectureConfig, +) from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig from vllm.config.pooler import PoolerConfig from vllm.config.scheduler import RunnerType -from vllm.config.utils import config, getattr_iter +from vllm.config.utils import config from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.transformers_utils.config import ( @@ -31,7 +33,6 @@ from vllm.transformers_utils.config import ( is_encoder_decoder, try_get_dense_modules, try_get_generation_config, - try_get_safetensors_metadata, try_get_tokenizer_config, 
uses_mrope, uses_xdrope_dim, @@ -42,10 +43,13 @@ from vllm.transformers_utils.gguf_utils import ( maybe_patch_hf_config_from_gguf, split_remote_gguf, ) +from vllm.transformers_utils.model_arch_config_convertor import ( + MODEL_ARCH_CONFIG_CONVERTORS, + ModelArchConfigConvertorBase, +) from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri from vllm.transformers_utils.utils import maybe_model_redirect from vllm.utils.import_utils import LazyLoader -from vllm.utils.torch_utils import common_broadcastable_dtype if TYPE_CHECKING: from transformers import PretrainedConfig @@ -504,6 +508,12 @@ class ModelConfig: self.hf_image_processor_config = get_hf_image_processor_config( self.model, hf_token=self.hf_token, revision=self.revision ) + self.model_arch_config = None + convertor_cls = MODEL_ARCH_CONFIG_CONVERTORS.get( + hf_config.model_type, ModelArchConfigConvertorBase + ) + convertor = convertor_cls(hf_config) + self.model_arch_config = convertor.convert(self.model, self.revision) architectures = self.architectures registry = self.registry @@ -765,7 +775,7 @@ class ModelConfig: @property def architectures(self) -> list[str]: - return getattr(self.hf_config, "architectures", []) + return self.model_arch_config.architectures @property def architecture(self) -> str: @@ -934,50 +944,16 @@ class ModelConfig: return "embed" - def _parse_quant_hf_config(self, hf_config: PretrainedConfig): - quant_cfg = getattr(hf_config, "quantization_config", None) - if quant_cfg is None: - # compressed-tensors uses a "compression_config" key - quant_cfg = getattr(hf_config, "compression_config", None) - - else: - # Set quant_method for ModelOpt models. 
- producer_name = quant_cfg.get("producer", {}).get("name") - if producer_name == "modelopt": - quant_algo = quant_cfg.get("quantization", {}).get("quant_algo") - if quant_algo == "FP8": - quant_cfg["quant_method"] = "modelopt" - elif quant_algo == "NVFP4": - quant_cfg["quant_method"] = "modelopt_fp4" - elif quant_algo is not None: - raise ValueError(f"Unknown ModelOpt quant algo: {quant_algo}") - - return quant_cfg - def _verify_quantization(self) -> None: supported_quantization = me_quant.QUANTIZATION_METHODS if self.quantization is not None: self.quantization = cast(me_quant.QuantizationMethods, self.quantization) # Parse quantization method from the HF model config, if available. - quant_cfg = self._parse_quant_hf_config(self.hf_config) - if quant_cfg is None and ( - text_config := getattr(self.hf_config, "text_config", None) - ): - # Check the text config as well for multi-modal models. - quant_cfg = self._parse_quant_hf_config(text_config) + quant_cfg = ModelArchConfigConvertorBase.get_quantization_config(self.hf_config) if quant_cfg is not None: - # Use the community standard 'quant_method' - quant_method = quant_cfg.get("quant_method", "").lower() - - # Normalize library names - quant_method = quant_method.replace( - "compressed_tensors", "compressed-tensors" - ) - - quant_cfg["quant_method"] = quant_method - + quant_method = quant_cfg["quant_method"] # Quantization methods which are overrides (i.e. they have a # `override_quantization_method` method) must be checked in order # of preference (this is particularly important for GPTQ). @@ -1059,7 +1035,7 @@ class ModelConfig: logger.warning( "CUDA graph is not supported for %s on ROCm yet, fallback " "to eager mode.", - self.hf_config.model_type, + self.model_arch_config.model_type, ) self.enforce_eager = True @@ -1070,11 +1046,9 @@ class ModelConfig: # TODO Remove this when bitsandbytes supports. 
""" is_bitsandbytes = self.quantization == "bitsandbytes" - has_quantization_config = ( - getattr(self.hf_config, "quantization_config", None) is not None - ) + has_quantization_config = self.model_arch_config.quantization_config is not None is_8bit = ( - self.hf_config.quantization_config.get("load_in_8bit", False) + self.model_arch_config.quantization_config.get("load_in_8bit", False) if has_quantization_config else False ) @@ -1128,9 +1102,7 @@ class ModelConfig: self, parallel_config: ParallelConfig, ) -> None: - total_num_attention_heads = getattr( - self.hf_text_config, "num_attention_heads", 0 - ) + total_num_attention_heads = self.model_arch_config.total_num_attention_heads tensor_parallel_size = parallel_config.tensor_parallel_size if total_num_attention_heads % tensor_parallel_size != 0: raise ValueError( @@ -1181,10 +1153,10 @@ class ModelConfig: return getattr(self.hf_text_config, "sliding_window", None) def get_vocab_size(self) -> int: - return getattr(self.hf_text_config, "vocab_size", 0) + return self.model_arch_config.vocab_size def get_hidden_size(self) -> int: - return getattr(self.hf_text_config, "hidden_size", 0) + return self.model_arch_config.hidden_size def get_inputs_embeds_size(self) -> int: # The size of inputs_embeds is usually identical to the size @@ -1198,29 +1170,7 @@ class ModelConfig: @property def is_deepseek_mla(self) -> bool: - if not hasattr(self.hf_text_config, "model_type"): - return False - elif self.hf_text_config.model_type in ( - "deepseek_v2", - "deepseek_v3", - "deepseek_v32", - "deepseek_mtp", - "kimi_k2", - "kimi_linear", - "longcat_flash", - "pangu_ultra_moe", - "pangu_ultra_moe_mtp", - ): - return self.hf_text_config.kv_lora_rank is not None - elif self.hf_text_config.model_type == "eagle": - # if the model is an EAGLE module, check for the - # underlying architecture - return ( - self.hf_text_config.model.model_type - in ("deepseek_v2", "deepseek_v3", "deepseek_v32") - and self.hf_text_config.kv_lora_rank is not 
None - ) - return False + return self.model_arch_config.is_deepseek_mla @cached_property def is_mm_prefix_lm(self) -> bool: @@ -1236,97 +1186,16 @@ class ModelConfig: return self.hf_config.model_type in MM_PREFIX_LM_MODELS def get_head_size(self) -> int: - # TODO remove hard code - if self.is_deepseek_mla: - qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim", 0) - if self.use_mla: - return self.hf_text_config.kv_lora_rank + qk_rope_head_dim - else: - qk_nope_head_dim = getattr(self.hf_text_config, "qk_nope_head_dim", 0) - if qk_rope_head_dim and qk_nope_head_dim: - return qk_rope_head_dim + qk_nope_head_dim - - if hasattr(self.hf_text_config, "model_type") and ( - self.hf_text_config.model_type == "zamba2" - ): - return self.hf_text_config.attention_head_dim - if self.is_attention_free: return 0 - - # NOTE: Some configs may set head_dim=None in the config - if getattr(self.hf_text_config, "head_dim", None) is not None: - return self.hf_text_config.head_dim - - # NOTE: Some models (such as PLaMo2.1) use `hidden_size_per_head` - if getattr(self.hf_text_config, "hidden_size_per_head", None) is not None: - return self.hf_text_config.hidden_size_per_head - - # FIXME(woosuk): This may not be true for all models. - return ( - self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads - ) + return self.model_arch_config.head_size def get_total_num_kv_heads(self) -> int: """Returns the total number of KV heads.""" - # For GPTBigCode & Falcon: - # NOTE: for falcon, when new_decoder_architecture is True, the - # multi_query flag is ignored and we use n_head_kv for the number of - # KV heads. - falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] - new_decoder_arch_falcon = ( - self.hf_config.model_type in falcon_model_types - and getattr(self.hf_config, "new_decoder_architecture", False) - ) - if not new_decoder_arch_falcon and getattr( - self.hf_text_config, "multi_query", False - ): - # Multi-query attention, only one KV head. 
- # Currently, tensor parallelism is not supported in this case. - return 1 - - # For DBRX and MPT - if self.hf_config.model_type == "mpt": - if "kv_n_heads" in self.hf_config.attn_config: - return self.hf_config.attn_config["kv_n_heads"] - return self.hf_config.num_attention_heads - if self.hf_config.model_type == "dbrx": - return getattr( - self.hf_config.attn_config, - "kv_n_heads", - self.hf_config.num_attention_heads, - ) - - if self.hf_config.model_type == "nemotron-nas": - for block in self.hf_config.block_configs: - if not block.attention.no_op: - return ( - self.hf_config.num_attention_heads - // block.attention.n_heads_in_group - ) - - raise RuntimeError("Couldn't determine number of kv heads") - if self.is_attention_free: return 0 - attributes = [ - # For Falcon: - "n_head_kv", - "num_kv_heads", - # For LLaMA-2: - "num_key_value_heads", - # For ChatGLM: - "multi_query_group_num", - ] - for attr in attributes: - num_kv_heads = getattr(self.hf_text_config, attr, None) - if num_kv_heads is not None: - return num_kv_heads - - # For non-grouped-query attention models, the number of KV heads is - # equal to the number of attention heads. 
- return self.hf_text_config.num_attention_heads + return self.model_arch_config.total_num_kv_heads def get_num_kv_heads(self, parallel_config: ParallelConfig) -> int: """Returns the number of KV heads per GPU.""" @@ -1342,46 +1211,14 @@ class ModelConfig: return max(1, total_num_kv_heads // parallel_config.tensor_parallel_size) def get_num_attention_heads(self, parallel_config: ParallelConfig) -> int: - num_heads = getattr(self.hf_text_config, "num_attention_heads", 0) + num_heads = self.model_arch_config.total_num_attention_heads return num_heads // parallel_config.tensor_parallel_size def get_num_experts(self) -> int: - """Returns the number of experts in the model.""" - num_expert_names = [ - "num_experts", # Jamba - "moe_num_experts", # Dbrx - "n_routed_experts", # DeepSeek - "num_local_experts", # Mixtral - ] - num_experts = getattr_iter(self.hf_text_config, num_expert_names, 0) - if isinstance(num_experts, list): - # Ernie VL's remote code uses list[int]... - # The values are always the same so we just take the first one. 
- return num_experts[0] - # Coerce to 0 if explicitly set to None - return num_experts or 0 + return self.model_arch_config.num_experts def get_total_num_hidden_layers(self) -> int: - if ( - self.hf_text_config.model_type == "deepseek_mtp" - or self.hf_config.model_type == "mimo_mtp" - or self.hf_config.model_type == "glm4_moe_mtp" - or self.hf_config.model_type == "ernie_mtp" - or self.hf_config.model_type == "qwen3_next_mtp" - or self.hf_config.model_type == "pangu_ultra_moe_mtp" - ): - total_num_hidden_layers = getattr( - self.hf_text_config, "num_nextn_predict_layers", 0 - ) - elif self.hf_config.model_type == "longcat_flash_mtp": - total_num_hidden_layers = getattr( - self.hf_text_config, "num_nextn_predict_layers", 1 - ) - else: - total_num_hidden_layers = getattr( - self.hf_text_config, "num_hidden_layers", 0 - ) - return total_num_hidden_layers + return self.model_arch_config.total_num_hidden_layers def get_layers_start_end_indices( self, parallel_config: ParallelConfig @@ -1432,9 +1269,7 @@ class ModelConfig: self.hf_text_config, "layers_block_type", None ) if layers_block_type_value is not None: - if hasattr(self.hf_text_config, "model_type") and ( - self.hf_text_config.model_type == "zamba2" - ): + if self.model_arch_config.text_model_type == "zamba2": if attn_block_type: return sum( t == "hybrid" for t in layers_block_type_value[start:end] @@ -1745,6 +1580,7 @@ class ModelConfig: ) max_model_len = _get_and_verify_max_len( hf_config=self.hf_text_config, + model_arch_config=self.model_arch_config, tokenizer_config=tokenizer_config, max_model_len=max_model_len, disable_sliding_window=self.disable_sliding_window, @@ -1969,46 +1805,6 @@ def _check_valid_dtype(model_type: str, dtype: torch.dtype): return True -def _find_dtype( - model_id: str, - config: PretrainedConfig, - *, - revision: str | None, -): - # NOTE: getattr(config, "dtype", torch.float32) is not correct - # because config.dtype can be None. 
- config_dtype = getattr(config, "dtype", None) - - # Fallbacks for multi-modal models if the root config - # does not define dtype - if config_dtype is None: - config_dtype = getattr(config.get_text_config(), "dtype", None) - if config_dtype is None and hasattr(config, "vision_config"): - config_dtype = getattr(config.vision_config, "dtype", None) - if config_dtype is None and hasattr(config, "encoder_config"): - config_dtype = getattr(config.encoder_config, "dtype", None) - - # Try to read the dtype of the weights if they are in safetensors format - if config_dtype is None: - repo_mt = try_get_safetensors_metadata(model_id, revision=revision) - - if repo_mt and (files_mt := repo_mt.files_metadata): - param_dtypes: set[torch.dtype] = { - _SAFETENSORS_TO_TORCH_DTYPE[dtype_str] - for file_mt in files_mt.values() - for dtype_str in file_mt.parameter_count - if dtype_str in _SAFETENSORS_TO_TORCH_DTYPE - } - - if param_dtypes: - return common_broadcastable_dtype(param_dtypes) - - if config_dtype is None: - config_dtype = torch.float32 - - return config_dtype - - def _resolve_auto_dtype( model_type: str, config_dtype: torch.dtype, @@ -2063,7 +1859,9 @@ def _get_and_verify_dtype( is_pooling_model: bool, revision: str | None = None, ) -> torch.dtype: - config_dtype = _find_dtype(model_id, config, revision=revision) + config_dtype = ModelArchConfigConvertorBase.get_torch_dtype( + config, model_id, revision=revision + ) model_type = config.model_type if isinstance(dtype, str): @@ -2126,6 +1924,7 @@ def _get_head_dtype( def _get_and_verify_max_len( hf_config: PretrainedConfig, + model_arch_config: ModelArchitectureConfig, tokenizer_config: dict | None, max_model_len: int | None, disable_sliding_window: bool, @@ -2134,36 +1933,9 @@ def _get_and_verify_max_len( encoder_config: Any | None = None, ) -> int: """Get and verify the model's maximum length.""" - derived_max_model_len = float("inf") - possible_keys = [ - # OPT - "max_position_embeddings", - # GPT-2 - "n_positions", - 
# MPT - "max_seq_len", - # ChatGLM2 - "seq_length", - # Command-R - "model_max_length", - # Whisper - "max_target_positions", - # Others - "max_sequence_length", - "max_seq_length", - "seq_len", - ] - # Choose the smallest "max_length" from the possible keys - max_len_key = None - for key in possible_keys: - max_len = getattr(hf_config, key, None) - if max_len is not None: - max_len_key = key if max_len < derived_max_model_len else max_len_key - derived_max_model_len = min(derived_max_model_len, max_len) - # For Command-R / Cohere, Cohere2 / Aya Vision models - if tmp_max_len := getattr(hf_config, "model_max_length", None): - max_len_key = "model_max_length" - derived_max_model_len = tmp_max_len + (derived_max_model_len, max_len_key) = ( + model_arch_config.derived_max_model_len_and_key + ) # If sliding window is manually disabled, max_length should be less # than the sliding window length in the model config. @@ -2196,10 +1968,9 @@ def _get_and_verify_max_len( default_max_len = 2048 logger.warning( - "The model's config.json does not contain any of the following " - "keys to determine the original maximum length of the model: " - "%s. Assuming the model's maximum length is %d.", - possible_keys, + "The model's config.json does not contain any of the keys " + "to determine the original maximum length of the model. 
" + "Assuming the model's maximum length is %d.", default_max_len, ) derived_max_model_len = default_max_len diff --git a/vllm/config/model_arch.py b/vllm/config/model_arch.py new file mode 100644 index 0000000000000..a35288c2330bf --- /dev/null +++ b/vllm/config/model_arch.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + +import torch +from pydantic import ConfigDict +from pydantic.dataclasses import dataclass + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) +class ModelArchitectureConfig: + """ + Configuration for model architecture that required by vLLM runtime + """ + + architectures: list[str] + """List of model architecture class names (e.g., ['LlamaForCausalLM']).""" + + model_type: str + """Model type identifier (e.g., 'llama', 'gpt_oss').""" + + text_model_type: str | None + """Text model type identifier (e.g., 'llama4_text').""" + + hidden_size: int + """Hidden size of the model.""" + + total_num_hidden_layers: int + """Number of hidden layers in the model.""" + + total_num_attention_heads: int + """Number of attention heads in the model.""" + + head_size: int + """Head dimension of the model.""" + + vocab_size: int + """Vocabulary size of the model.""" + + total_num_kv_heads: int + """Number of key value heads in the model.""" + + num_experts: int + """Number of experts in the model.""" + + quantization_config: dict[str, Any] | None + """Quantization configuration dictionary containing quantization parameters.""" + + torch_dtype: torch.dtype | str | None + """PyTorch data type for model weights (e.g., 'float16', 'bfloat16').""" + + is_multimodal_model: bool + """Whether the model is a multimodal model.""" + + is_deepseek_mla: bool + """Whether the model is a DeepSeek MLA model.""" + + derived_max_model_len_and_key: tuple[float, str | None] + """Derived maximum 
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Convertors from HuggingFace ``PretrainedConfig`` to
:class:`ModelArchitectureConfig`.

``ModelArchConfigConvertorBase`` implements the generic extraction rules;
per-model subclasses override only the accessors whose HF config layout
deviates from the common convention. ``MODEL_ARCH_CONFIG_CONVERTORS`` maps
``hf_config.model_type`` to the convertor class to use.
"""

from typing import TYPE_CHECKING

import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers import PretrainedConfig

from vllm import envs
from vllm.config.model_arch import (
    ModelArchitectureConfig,
)
from vllm.config.utils import getattr_iter
from vllm.logger import init_logger
from vllm.transformers_utils.config import (
    get_hf_text_config,
    try_get_safetensors_metadata,
)
from vllm.utils.import_utils import LazyLoader
from vllm.utils.torch_utils import common_broadcastable_dtype

if TYPE_CHECKING:
    import vllm.model_executor.models.registry as me_models_registry
else:
    # Use lazy loading to avoid circular import
    me_models_registry = LazyLoader(
        "model_executor", globals(), "vllm.model_executor.models.registry"
    )

logger = init_logger(__name__)


class ModelArchConfigConvertorBase:
    """Extracts vLLM-relevant architecture fields from a HF config.

    Subclass and override individual ``get_*`` accessors for model families
    whose HF config uses non-standard attribute names.
    """

    def __init__(self, hf_config: PretrainedConfig):
        self.hf_config = hf_config
        # For multimodal models the interesting fields live on the nested
        # text config; for text-only models this is hf_config itself.
        self.hf_text_config = get_hf_text_config(hf_config)

    def get_num_hidden_layers(self) -> int:
        """Return the number of hidden layers (0 if the config lacks it)."""
        return getattr(self.hf_text_config, "num_hidden_layers", 0)

    def get_total_num_attention_heads(self) -> int:
        """Return the total number of attention heads (0 if absent)."""
        return getattr(self.hf_text_config, "num_attention_heads", 0)

    def get_vocab_size(self) -> int:
        """Return the vocabulary size (0 if absent)."""
        return getattr(self.hf_text_config, "vocab_size", 0)

    def get_hidden_size(self) -> int:
        """Return the model hidden size (0 if absent)."""
        return getattr(self.hf_text_config, "hidden_size", 0)

    def get_head_size(self) -> int:
        """Return the per-head dimension.

        For DeepSeek MLA models the "head size" is the KV-cache entry width:
        kv_lora_rank + qk_rope_head_dim when MLA is enabled, or the full
        qk_rope + qk_nope width when MLA is disabled. Otherwise fall back to
        explicit head_dim / hidden_size_per_head attributes, and finally to
        hidden_size // num_attention_heads.
        """
        if self.is_deepseek_mla():
            qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim", 0)
            if not envs.VLLM_MLA_DISABLE:
                return self.hf_text_config.kv_lora_rank + qk_rope_head_dim
            else:
                qk_nope_head_dim = getattr(self.hf_text_config, "qk_nope_head_dim", 0)
                if qk_rope_head_dim and qk_nope_head_dim:
                    return qk_rope_head_dim + qk_nope_head_dim

        # NOTE: Some configs may set head_dim=None in the config
        if getattr(self.hf_text_config, "head_dim", None) is not None:
            return self.hf_text_config.head_dim

        # NOTE: Some models (such as PLaMo2.1) use `hidden_size_per_head`
        if getattr(self.hf_text_config, "hidden_size_per_head", None) is not None:
            return self.hf_text_config.hidden_size_per_head

        # FIXME(woosuk): This may not be true for all models.
        return (
            self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads
        )

    def get_total_num_kv_heads(self) -> int:
        """Return the number of KV heads, trying model-specific attribute
        names first and falling back to num_attention_heads (i.e. MHA)."""
        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
        ]
        for attr in attributes:
            num_kv_heads = getattr(self.hf_text_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads

        # No dedicated KV-head attribute => plain multi-head attention.
        return self.hf_text_config.num_attention_heads

    def get_num_experts(self) -> int:
        """Return the number of MoE experts (0 for dense models)."""
        num_expert_names = [
            "num_experts",  # Jamba
            "moe_num_experts",  # Dbrx
            "n_routed_experts",  # DeepSeek
            "num_local_experts",  # Mixtral
        ]
        num_experts = getattr_iter(self.hf_text_config, num_expert_names, 0)
        if isinstance(num_experts, list):
            # Ernie VL's remote code uses list[int]...
            # The values are always the same so we just take the first one.
            return num_experts[0]
        # Coerce to 0 if explicitly set to None
        return num_experts or 0

    @classmethod
    def get_torch_dtype(cls, hf_config, model_id: str, revision: str | None):
        """Resolve the weight dtype from the config, nested sub-configs, or
        (as a last resort) safetensors file metadata; default float32."""
        # NOTE: getattr(config, "dtype", torch.float32) is not correct
        # because config.dtype can be None.
        config_dtype = getattr(hf_config, "dtype", None)

        # Fallbacks for multi-modal models if the root config
        # does not define dtype
        if config_dtype is None:
            config_dtype = getattr(hf_config.get_text_config(), "dtype", None)
        if config_dtype is None and hasattr(hf_config, "vision_config"):
            config_dtype = getattr(hf_config.vision_config, "dtype", None)
        if config_dtype is None and hasattr(hf_config, "encoder_config"):
            config_dtype = getattr(hf_config.encoder_config, "dtype", None)

        # Try to read the dtype of the weights if they are in safetensors format
        if config_dtype is None:
            repo_mt = try_get_safetensors_metadata(model_id, revision=revision)

            if repo_mt and (files_mt := repo_mt.files_metadata):
                param_dtypes: set[torch.dtype] = {
                    _SAFETENSORS_TO_TORCH_DTYPE[dtype_str]
                    for file_mt in files_mt.values()
                    for dtype_str in file_mt.parameter_count
                    if dtype_str in _SAFETENSORS_TO_TORCH_DTYPE
                }

                if param_dtypes:
                    return common_broadcastable_dtype(param_dtypes)

        if config_dtype is None:
            config_dtype = torch.float32

        return config_dtype

    @classmethod
    def _normalize_quantization_config(cls, config: PretrainedConfig):
        """Return the quantization config dict with a normalized
        'quant_method' key, or None if the model is unquantized.

        NOTE: mutates the dict on the config in place (sets 'quant_method').
        """
        quant_cfg = getattr(config, "quantization_config", None)
        if quant_cfg is None:
            # compressed-tensors uses a "compression_config" key
            quant_cfg = getattr(config, "compression_config", None)

        else:
            # Set quant_method for ModelOpt models.
            producer_name = quant_cfg.get("producer", {}).get("name")
            if producer_name == "modelopt":
                quant_algo = quant_cfg.get("quantization", {}).get("quant_algo")
                if quant_algo == "FP8":
                    quant_cfg["quant_method"] = "modelopt"
                elif quant_algo == "NVFP4":
                    quant_cfg["quant_method"] = "modelopt_fp4"
                elif quant_algo is not None:
                    raise ValueError(f"Unknown ModelOpt quant algo: {quant_algo}")

        if quant_cfg is not None:
            # Use the community standard 'quant_method'
            quant_method = quant_cfg.get("quant_method", "").lower()

            # Normalize library names
            quant_method = quant_method.replace(
                "compressed_tensors", "compressed-tensors"
            )

            quant_cfg["quant_method"] = quant_method

        return quant_cfg

    @classmethod
    def get_quantization_config(cls, hf_config: PretrainedConfig):
        """Return the normalized quantization config from the root config or,
        for multi-modal models, from the nested text config."""
        quant_cfg = cls._normalize_quantization_config(hf_config)
        if quant_cfg is None and (
            text_config := getattr(hf_config, "text_config", None)
        ):
            # Check the text config as well for multi-modal models.
            quant_cfg = cls._normalize_quantization_config(text_config)
        return quant_cfg

    def is_deepseek_mla(self) -> bool:
        """Whether the model uses DeepSeek multi-head latent attention
        (requires a known model_type AND a non-None kv_lora_rank)."""
        if not hasattr(self.hf_text_config, "model_type"):
            return False
        elif self.hf_text_config.model_type in (
            "deepseek_v2",
            "deepseek_v3",
            "deepseek_v32",
            "deepseek_mtp",
            "kimi_k2",
            "kimi_linear",
            "longcat_flash",
            "pangu_ultra_moe",
            "pangu_ultra_moe_mtp",
        ):
            return self.hf_text_config.kv_lora_rank is not None
        elif self.hf_text_config.model_type == "eagle":
            # if the model is an EAGLE module, check for the
            # underlying architecture
            return (
                self.hf_text_config.model.model_type
                in ("deepseek_v2", "deepseek_v3", "deepseek_v32")
                and self.hf_text_config.kv_lora_rank is not None
            )
        return False

    def derive_max_model_len_and_key(self) -> tuple[float, str | None]:
        """Return (max model length, config key it came from).

        Takes the smallest value among the known length keys; returns
        (inf, None) when no key is present. `model_max_length`, when set,
        deliberately overrides the minimum (Command-R / Cohere convention).
        """
        derived_max_model_len = float("inf")
        possible_keys = [
            # OPT
            "max_position_embeddings",
            # GPT-2
            "n_positions",
            # MPT
            "max_seq_len",
            # ChatGLM2
            "seq_length",
            # Command-R
            "model_max_length",
            # Whisper
            "max_target_positions",
            # Others
            "max_sequence_length",
            "max_seq_length",
            "seq_len",
        ]
        # Choose the smallest "max_length" from the possible keys
        max_len_key = None
        for key in possible_keys:
            max_len = getattr(self.hf_text_config, key, None)
            if max_len is not None:
                if max_len < derived_max_model_len:
                    max_len_key = key
                derived_max_model_len = min(derived_max_model_len, max_len)

        # For Command-R / Cohere, Cohere2 / Aya Vision models
        if tmp_max_len := getattr(self.hf_text_config, "model_max_length", None):
            max_len_key = "model_max_length"
            derived_max_model_len = tmp_max_len
        return derived_max_model_len, max_len_key

    def is_multimodal_model(self) -> bool:
        """Whether any declared architecture is a registered multimodal one."""
        return any(
            multi_model_arch in self.hf_config.architectures
            for multi_model_arch in me_models_registry._MULTIMODAL_MODELS
        )

    def convert(self, model_id: str, revision: str | None) -> ModelArchitectureConfig:
        """Assemble a ModelArchitectureConfig from this convertor's accessors.

        model_id/revision are only needed for the safetensors-metadata dtype
        fallback in get_torch_dtype().
        """
        model_arch_config = ModelArchitectureConfig(
            architectures=getattr(self.hf_config, "architectures", []),
            model_type=self.hf_config.model_type,
            text_model_type=getattr(self.hf_text_config, "model_type", None),
            hidden_size=self.get_hidden_size(),
            total_num_hidden_layers=self.get_num_hidden_layers(),
            total_num_attention_heads=self.get_total_num_attention_heads(),
            head_size=self.get_head_size(),
            vocab_size=self.get_vocab_size(),
            total_num_kv_heads=self.get_total_num_kv_heads(),
            num_experts=self.get_num_experts(),
            quantization_config=self.get_quantization_config(self.hf_config),
            torch_dtype=self.get_torch_dtype(self.hf_config, model_id, revision),
            is_multimodal_model=self.is_multimodal_model(),
            is_deepseek_mla=self.is_deepseek_mla(),
            derived_max_model_len_and_key=self.derive_max_model_len_and_key(),
        )

        return model_arch_config


class Zamba2ModelArchConfigConvertor(ModelArchConfigConvertorBase):
    """Zamba2 stores the head dimension in `attention_head_dim`."""

    def get_head_size(self) -> int:
        return getattr(self.hf_text_config, "attention_head_dim", 0)


class FalconModelArchConfigConvertor(ModelArchConfigConvertorBase):
    """Falcon / RefinedWeb KV-head handling."""

    def get_total_num_kv_heads(self) -> int:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        new_decoder_arch_falcon = getattr(
            self.hf_text_config, "new_decoder_architecture", False
        )

        if not new_decoder_arch_falcon and getattr(
            self.hf_text_config, "multi_query", False
        ):
            # Multi-query attention, only one KV head.
            return 1

        # Use the base implementation which checks n_head_kv, num_kv_heads, etc.
        return super().get_total_num_kv_heads()


class MPTModelArchConfigConvertor(ModelArchConfigConvertorBase):
    """MPT nests KV-head count in the `attn_config` dict."""

    def get_total_num_kv_heads(self) -> int:
        if "kv_n_heads" in self.hf_text_config.attn_config:
            return self.hf_text_config.attn_config["kv_n_heads"]
        return self.hf_text_config.num_attention_heads


class DbrxModelArchConfigConvertor(ModelArchConfigConvertorBase):
    """Dbrx nests KV-head count in an `attn_config` object (not a dict)."""

    def get_total_num_kv_heads(self) -> int:
        return getattr(
            self.hf_text_config.attn_config,
            "kv_n_heads",
            self.hf_text_config.num_attention_heads,
        )


class NemotronNasModelArchConfigConvertor(ModelArchConfigConvertorBase):
    """Nemotron-NAS derives KV heads from the first non-no-op block."""

    def get_total_num_kv_heads(self) -> int:
        for block in self.hf_text_config.block_configs:
            if not block.attention.no_op:
                return (
                    self.hf_text_config.num_attention_heads
                    // block.attention.n_heads_in_group
                )
        raise RuntimeError("Couldn't determine number of kv heads")


class _MTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
    """Shared base for MTP (multi-token-prediction) draft modules.

    MTP modules store their depth in `num_nextn_predict_layers` instead of
    `num_hidden_layers`. Subclasses only differ in the fallback default.
    """

    # Fallback when the config does not define num_nextn_predict_layers.
    _default_num_nextn_predict_layers: int = 0

    def get_num_hidden_layers(self) -> int:
        return getattr(
            self.hf_text_config,
            "num_nextn_predict_layers",
            self._default_num_nextn_predict_layers,
        )


class DeepSeekMTPModelArchConfigConvertor(_MTPModelArchConfigConvertor):
    pass


class MimoMTPModelArchConfigConvertor(_MTPModelArchConfigConvertor):
    pass


class GLM4MoeMTPModelArchConfigConvertor(_MTPModelArchConfigConvertor):
    pass


class ErnieMTPModelArchConfigConvertor(_MTPModelArchConfigConvertor):
    pass


class Qwen3NextMTPModelArchConfigConvertor(_MTPModelArchConfigConvertor):
    pass


class PanguUltraMoeMTPModelArchConfigConvertor(_MTPModelArchConfigConvertor):
    pass


class LongCatFlashMTPModelArchConfigConvertor(_MTPModelArchConfigConvertor):
    # LongCat Flash assumes one next-n layer when the key is absent.
    _default_num_nextn_predict_layers = 1


# hf_config.model_type -> convertor class
MODEL_ARCH_CONFIG_CONVERTORS = {
    "zamba2": Zamba2ModelArchConfigConvertor,
    "mpt": MPTModelArchConfigConvertor,
    "dbrx": DbrxModelArchConfigConvertor,
    "falcon": FalconModelArchConfigConvertor,
    "RefinedWeb": FalconModelArchConfigConvertor,
    "RefinedWebModel": FalconModelArchConfigConvertor,
    "nemotron-nas": NemotronNasModelArchConfigConvertor,
    "deepseek_mtp": DeepSeekMTPModelArchConfigConvertor,
    "qwen3_next_mtp": Qwen3NextMTPModelArchConfigConvertor,
    "mimo_mtp": MimoMTPModelArchConfigConvertor,
    "glm4_moe_mtp": GLM4MoeMTPModelArchConfigConvertor,
    "ernie_mtp": ErnieMTPModelArchConfigConvertor,
    "pangu_ultra_moe_mtp": PanguUltraMoeMTPModelArchConfigConvertor,
    "longcat_flash_mtp": LongCatFlashMTPModelArchConfigConvertor,
}