[Model] Introduce verify_and_update_model_config for VerifyAndUpdateConfig. (#31131)

Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
wang.yuqi 2025-12-24 17:54:57 +08:00 committed by GitHub
parent b41aeb3468
commit bd89ce16d2
3 changed files with 48 additions and 31 deletions
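
Summary of the API shape (my reading of the diffs below, hedged): the existing verify_and_update_config hook receives a fully built VllmConfig, so it can only run once that object exists; the new verify_and_update_model_config hook receives only the ModelConfig, which lets ModelConfig.__post_init__ apply per-architecture fixups on itself. In sketch form, simplified from the diff:

class VerifyAndUpdateConfig:
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        # Full-config hook; after this commit a no-op default
        # instead of raising NotImplementedError.
        return

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        # New model-only hook, callable before a VllmConfig exists.
        return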

vllm/config/model.py

@@ -595,7 +595,7 @@ class ModelConfig:
         # Avoid running try_verify_and_update_config multiple times
         self.config_updated = False
-
+        self._try_verify_and_update_model_config()
         self._verify_quantization()
         self._verify_cuda_graph()
         self._verify_bnb_config()
@@ -1008,6 +1008,23 @@ class ModelConfig:
                 "when expert parallelism is enabled."
             )
 
+    def _try_verify_and_update_model_config(self):
+        # Avoid running try_verify_and_update_config multiple times
+        if getattr(self, "config_updated", False):
+            return
+
+        architecture = self.architecture
+        if architecture is None:
+            return
+
+        from vllm.model_executor.models.config import (
+            MODELS_CONFIG_MAP,
+        )
+
+        cls = MODELS_CONFIG_MAP.get(architecture, None)
+        if cls is not None:
+            cls.verify_and_update_model_config(self)
+
     def verify_dual_chunk_attention_config(
         self,
         load_config: LoadConfig,
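
To see the dispatch in isolation, here is a self-contained toy version of the pattern this hunk adds; _DemoConfig, the map contents, and the attributes are stand-ins, and only the control flow mirrors the real method:

class _DemoConfig:
    @staticmethod
    def verify_and_update_model_config(model_config) -> None:
        model_config.patched = True  # stand-in for a real hf_config fixup

MODELS_CONFIG_MAP = {"DemoForCausalLM": _DemoConfig}

class ModelConfig:
    def __init__(self, architecture="DemoForCausalLM"):
        self.architecture = architecture
        self.config_updated = False
        self.patched = False
        self._try_verify_and_update_model_config()

    def _try_verify_and_update_model_config(self):
        if getattr(self, "config_updated", False):
            return  # already verified once; skip
        if self.architecture is None:
            return  # nothing to look up
        cls = MODELS_CONFIG_MAP.get(self.architecture, None)
        if cls is not None:
            cls.verify_and_update_model_config(self)

assert ModelConfig().patched
assert not ModelConfig(architecture=None).patched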

vllm/model_executor/models/config.py

@@ -13,7 +13,7 @@ from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec
 
 if TYPE_CHECKING:
-    from vllm.config import VllmConfig
+    from vllm.config import ModelConfig, VllmConfig
 
 logger = init_logger(__name__)
@@ -21,20 +21,24 @@ logger = init_logger(__name__)
 class VerifyAndUpdateConfig:
     @staticmethod
     def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        raise NotImplementedError
+        return
+
+    @staticmethod
+    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
+        return
 
 
-class Gemma3TextModelConfig:
+class Gemma3TextModelConfig(VerifyAndUpdateConfig):
     @staticmethod
-    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        hf_config = vllm_config.model_config.hf_config
+    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
+        hf_config = model_config.hf_config
         hf_config.is_causal = not hf_config.use_bidirectional_attention
 
 
 class GteNewModelConfig(VerifyAndUpdateConfig):
     @staticmethod
-    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        config = vllm_config.model_config.hf_config
+    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
+        config = model_config.hf_config
         assert config.__class__.__name__ == "NewConfig"
         assert config.hidden_act == "gelu"
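
For a new architecture, a per-model override under the new API would look roughly like this; FooModelConfig and its use_bidirectional_attention fixup are hypothetical, and the class would additionally need an entry in MODELS_CONFIG_MAP keyed by the architecture name so that ModelConfig._try_verify_and_update_model_config can find it:

class FooModelConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        hf_config = model_config.hf_config
        # Hypothetical checkpoint quirk, mirroring the Gemma3 pattern above.
        if getattr(hf_config, "use_bidirectional_attention", False):
            hf_config.is_causal = False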
@@ -53,16 +57,15 @@ class GteNewModelConfig(VerifyAndUpdateConfig):
 class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
     @staticmethod
-    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        pooler_config = vllm_config.model_config.pooler_config
+    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
+        pooler_config = model_config.pooler_config
         if pooler_config.use_activation is None:
             pooler_config.use_activation = False
 
 
 class JinaRobertaModelConfig(VerifyAndUpdateConfig):
     @staticmethod
-    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        model_config = vllm_config.model_config
+    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
         config = model_config.hf_config
 
         if config.position_embedding_type == "rotary":
@@ -90,10 +93,10 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig):
 class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
     @staticmethod
-    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
         from vllm.config.pooler import PoolingTypeStr
 
-        hf_config = vllm_config.model_config.hf_config
+        hf_config = model_config.hf_config
         hf_config.is_causal = False
 
         pooling_type_map: dict[str, PoolingTypeStr] = {
@@ -105,7 +108,7 @@ class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
         pooling_type = pooling_type_map.get(hf_config.pooling, None)
         if pooling_type is None:
             raise ValueError(f"pool_type {hf_config.pooling} not supported")
-        vllm_config.model_config.pooler_config.pooling_type = pooling_type
+        model_config.pooler_config.pooling_type = pooling_type
 
 
 class NomicBertModelConfig(VerifyAndUpdateConfig):
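
The pooling-type translation above is a lookup-or-fail over a small map; a standalone toy version (these map entries are assumptions for illustration, the real keys and PoolingTypeStr values live in vllm.config.pooler):

pooling_type_map = {"mean": "MEAN", "cls": "CLS", "lasttoken": "LAST"}

def resolve_pooling_type(hf_pooling: str) -> str:
    pooling_type = pooling_type_map.get(hf_pooling, None)
    if pooling_type is None:
        raise ValueError(f"pool_type {hf_pooling} not supported")
    return pooling_type

assert resolve_pooling_type("mean") == "MEAN"  # unknown names raise ValueError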
@@ -204,8 +207,8 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
 class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig):
     @staticmethod
-    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        pooler_config = vllm_config.model_config.pooler_config
+    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
+        pooler_config = model_config.pooler_config
 
         if pooler_config.step_tag_id is None:
             pooler_config.step_tag_id = 151651
@@ -213,8 +216,8 @@ class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig):
 class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig):
     @staticmethod
-    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        pooler_config = vllm_config.model_config.pooler_config
+    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
+        pooler_config = model_config.pooler_config
 
         if pooler_config.softmax is None:
             pooler_config.softmax = False
@@ -222,8 +225,8 @@ class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig):
 class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
     @staticmethod
-    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        config = vllm_config.model_config.hf_config
+    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
+        config = model_config.hf_config
 
         is_original_qwen3_reranker = getattr(
             config, "is_original_qwen3_reranker", False
@@ -237,23 +240,23 @@ class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
                 "Try loading the original Qwen3 Reranker?, see: "
                 "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/offline_reranker.py"
             )
-        vllm_config.model_config.hf_config.method = "from_2_way_softmax"
+        model_config.hf_config.method = "from_2_way_softmax"
 
 
 class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig):
     @staticmethod
-    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        config = vllm_config.model_config.hf_config
+    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
+        config = model_config.hf_config
         config.num_labels = 1
-        pooler_config = vllm_config.model_config.pooler_config
+        pooler_config = model_config.pooler_config
 
         if pooler_config.logit_bias is None:
             pooler_config.logit_bias = 2.65
 
 
 class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
     @staticmethod
-    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        config = vllm_config.model_config.hf_config
+    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
+        config = model_config.hf_config
         assert config.__class__.__name__ == "GteConfig"
         assert config.hidden_act == "gelu"
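
One idiom repeats across the Jamba, Qwen2 and JinaVL hunks: a model default is written only when the pooler field is still None, so an explicit user setting always wins. A minimal illustration with a made-up PoolerConfig:

from dataclasses import dataclass

@dataclass
class PoolerConfig:
    logit_bias: float | None = None

def apply_model_default(pooler_config: PoolerConfig) -> None:
    # Only fill in the model's default if the user left the field unset.
    if pooler_config.logit_bias is None:
        pooler_config.logit_bias = 2.65  # model default, as in the JinaVL hunk

cfg = PoolerConfig(logit_bias=0.0)
apply_model_default(cfg)
assert cfg.logit_bias == 0.0  # user-provided value is preserved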

vllm/model_executor/models/llama.py

@@ -64,7 +64,6 @@ from .interfaces import (
     SupportsLoRA,
     SupportsPP,
 )
-from .interfaces_base import attn_type
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
@@ -707,14 +706,12 @@ class LlamaForCausalLM(
         return name, loaded_weight
 
 
-@attn_type("encoder_only")
 class LlamaBidirectionalForSequenceClassification(as_seq_cls_model(LlamaForCausalLM)):
     # This class sets the correct attention type and pooling type
     # through LlamaBidirectionalConfig.
     pass
 
 
-@attn_type("encoder_only")
 class LlamaBidirectionalModel(as_embedding_model(LlamaForCausalLM)):
     # This class sets the correct attention type and pooling type
     # through LlamaBidirectionalConfig.
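
My reading of the llama.py hunks, hedged: the @attn_type("encoder_only") decorator is dropped because LlamaBidirectionalConfig.verify_and_update_model_config now sets hf_config.is_causal = False, and the attention type appears to be derived from that flag instead of being pinned per class. A toy version of that derivation (not vLLM's actual code path):

def attn_type_for(hf_config) -> str:
    # Assumption: non-causal attention implies an encoder-only model.
    is_causal = getattr(hf_config, "is_causal", True)
    return "decoder" if is_causal else "encoder_only"

class _Cfg:
    is_causal = False  # as set by LlamaBidirectionalConfig above

assert attn_type_for(_Cfg()) == "encoder_only"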