From bd89ce16d216f33d93cb72d0a88b2a98d726784a Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Wed, 24 Dec 2025 17:54:57 +0800 Subject: [PATCH] [Model] Introduce verify_and_update_model_config for VerifyAndUpdateConfig. (#31131) Signed-off-by: wang.yuqi Signed-off-by: wang.yuqi --- vllm/config/model.py | 19 +++++++++- vllm/model_executor/models/config.py | 57 +++++++++++++++------------- vllm/model_executor/models/llama.py | 3 -- 3 files changed, 48 insertions(+), 31 deletions(-) diff --git a/vllm/config/model.py b/vllm/config/model.py index c3e23de220949..ce554b136cef3 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -595,7 +595,7 @@ class ModelConfig: # Avoid running try_verify_and_update_config multiple times self.config_updated = False - + self._try_verify_and_update_model_config() self._verify_quantization() self._verify_cuda_graph() self._verify_bnb_config() @@ -1008,6 +1008,23 @@ class ModelConfig: "when expert parallelism is enabled." ) + def _try_verify_and_update_model_config(self): + # Avoid running try_verify_and_update_config multiple times + if getattr(self, "config_updated", False): + return + + architecture = self.architecture + if architecture is None: + return + + from vllm.model_executor.models.config import ( + MODELS_CONFIG_MAP, + ) + + cls = MODELS_CONFIG_MAP.get(architecture, None) + if cls is not None: + cls.verify_and_update_model_config(self) + def verify_dual_chunk_attention_config( self, load_config: LoadConfig, diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index d33b3fdf47467..10fd599f9e5f8 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -13,7 +13,7 @@ from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec if TYPE_CHECKING: - from vllm.config import VllmConfig + from vllm.config import ModelConfig, VllmConfig logger = init_logger(__name__) @@ -21,20 +21,24 @@ logger = init_logger(__name__) class VerifyAndUpdateConfig: @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: - raise NotImplementedError + return - -class Gemma3TextModelConfig: @staticmethod - def verify_and_update_config(vllm_config: "VllmConfig") -> None: - hf_config = vllm_config.model_config.hf_config + def verify_and_update_model_config(model_config: "ModelConfig") -> None: + return + + +class Gemma3TextModelConfig(VerifyAndUpdateConfig): + @staticmethod + def verify_and_update_model_config(model_config: "ModelConfig") -> None: + hf_config = model_config.hf_config hf_config.is_causal = not hf_config.use_bidirectional_attention class GteNewModelConfig(VerifyAndUpdateConfig): @staticmethod - def verify_and_update_config(vllm_config: "VllmConfig") -> None: - config = vllm_config.model_config.hf_config + def verify_and_update_model_config(model_config: "ModelConfig") -> None: + config = model_config.hf_config assert config.__class__.__name__ == "NewConfig" assert config.hidden_act == "gelu" @@ -53,16 +57,15 @@ class GteNewModelConfig(VerifyAndUpdateConfig): class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig): @staticmethod - def verify_and_update_config(vllm_config: "VllmConfig") -> None: - pooler_config = vllm_config.model_config.pooler_config + def verify_and_update_model_config(model_config: "ModelConfig") -> None: + pooler_config = model_config.pooler_config if pooler_config.use_activation is None: pooler_config.use_activation = False class 
JinaRobertaModelConfig(VerifyAndUpdateConfig): @staticmethod - def verify_and_update_config(vllm_config: "VllmConfig") -> None: - model_config = vllm_config.model_config + def verify_and_update_model_config(model_config: "ModelConfig") -> None: config = model_config.hf_config if config.position_embedding_type == "rotary": @@ -90,10 +93,10 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig): class LlamaBidirectionalConfig(VerifyAndUpdateConfig): @staticmethod - def verify_and_update_config(vllm_config: "VllmConfig") -> None: + def verify_and_update_model_config(model_config: "ModelConfig") -> None: from vllm.config.pooler import PoolingTypeStr - hf_config = vllm_config.model_config.hf_config + hf_config = model_config.hf_config hf_config.is_causal = False pooling_type_map: dict[str, PoolingTypeStr] = { @@ -105,7 +108,7 @@ class LlamaBidirectionalConfig(VerifyAndUpdateConfig): pooling_type = pooling_type_map.get(hf_config.pooling, None) if pooling_type is None: raise ValueError(f"pool_type {hf_config.pooling} not supported") - vllm_config.model_config.pooler_config.pooling_type = pooling_type + model_config.pooler_config.pooling_type = pooling_type class NomicBertModelConfig(VerifyAndUpdateConfig): @@ -204,8 +207,8 @@ class NomicBertModelConfig(VerifyAndUpdateConfig): class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig): @staticmethod - def verify_and_update_config(vllm_config: "VllmConfig") -> None: - pooler_config = vllm_config.model_config.pooler_config + def verify_and_update_model_config(model_config: "ModelConfig") -> None: + pooler_config = model_config.pooler_config if pooler_config.step_tag_id is None: pooler_config.step_tag_id = 151651 @@ -213,8 +216,8 @@ class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig): class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig): @staticmethod - def verify_and_update_config(vllm_config: "VllmConfig") -> None: - pooler_config = vllm_config.model_config.pooler_config + def verify_and_update_model_config(model_config: "ModelConfig") -> None: + pooler_config = model_config.pooler_config if pooler_config.softmax is None: pooler_config.softmax = False @@ -222,8 +225,8 @@ class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig): class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig): @staticmethod - def verify_and_update_config(vllm_config: "VllmConfig") -> None: - config = vllm_config.model_config.hf_config + def verify_and_update_model_config(model_config: "ModelConfig") -> None: + config = model_config.hf_config is_original_qwen3_reranker = getattr( config, "is_original_qwen3_reranker", False @@ -237,23 +240,23 @@ class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig): "Try loading the original Qwen3 Reranker?, see: " "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/offline_reranker.py" ) - vllm_config.model_config.hf_config.method = "from_2_way_softmax" + model_config.hf_config.method = "from_2_way_softmax" class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig): @staticmethod - def verify_and_update_config(vllm_config: "VllmConfig") -> None: - config = vllm_config.model_config.hf_config + def verify_and_update_model_config(model_config: "ModelConfig") -> None: + config = model_config.hf_config config.num_labels = 1 - pooler_config = vllm_config.model_config.pooler_config + pooler_config = model_config.pooler_config if pooler_config.logit_bias is None: pooler_config.logit_bias = 2.65 class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig): @staticmethod - def 
verify_and_update_config(vllm_config: "VllmConfig") -> None: - config = vllm_config.model_config.hf_config + def verify_and_update_model_config(model_config: "ModelConfig") -> None: + config = model_config.hf_config assert config.__class__.__name__ == "GteConfig" assert config.hidden_act == "gelu" diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 84f4211df4c20..f0f2983f84637 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -64,7 +64,6 @@ from .interfaces import ( SupportsLoRA, SupportsPP, ) -from .interfaces_base import attn_type from .utils import ( AutoWeightsLoader, PPMissingLayer, @@ -707,14 +706,12 @@ class LlamaForCausalLM( return name, loaded_weight -@attn_type("encoder_only") class LlamaBidirectionalForSequenceClassification(as_seq_cls_model(LlamaForCausalLM)): # This class sets the correct attention type and pooling type # through LlamaBidirectionalConfig. pass -@attn_type("encoder_only") class LlamaBidirectionalModel(as_embedding_model(LlamaForCausalLM)): # This class sets the correct attention type and pooling type # through LlamaBidirectionalConfig.
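
Illustrative sketch (not part of the patch above): the hook introduced here,
verify_and_update_model_config, receives only a ModelConfig and is dispatched
from ModelConfig._try_verify_and_update_model_config() via MODELS_CONFIG_MAP
while the ModelConfig is being initialized, before the quantization and
CUDA-graph checks. Per-model tweaks such as pooler defaults or
hf_config.is_causal therefore no longer need the full VllmConfig. The class
below is hypothetical and only shows how a model-specific config class is
expected to plug into that path under the new signature.

# Hypothetical example -- "MyModelConfig" is not part of this patch; the
# dispatch path (MODELS_CONFIG_MAP -> verify_and_update_model_config) is the
# one introduced above.
from vllm.config import ModelConfig
from vllm.model_executor.models.config import VerifyAndUpdateConfig


class MyModelConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_model_config(model_config: ModelConfig) -> None:
        # Runs during ModelConfig initialization, so only model-level fields
        # (hf_config, pooler_config, ...) should be adjusted here.
        hf_config = model_config.hf_config
        hf_config.is_causal = False

# For the hook to run automatically, such a class would be registered in
# MODELS_CONFIG_MAP in vllm/model_executor/models/config.py, keyed by the
# model's architecture name.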