mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-27 02:21:48 +08:00
[Model] Introduce verify_and_update_model_config for VerifyAndUpdateConfig. (#31131)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io> Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
parent
b41aeb3468
commit
bd89ce16d2
@ -595,7 +595,7 @@ class ModelConfig:
|
||||
|
||||
# Avoid running try_verify_and_update_config multiple times
|
||||
self.config_updated = False
|
||||
|
||||
self._try_verify_and_update_model_config()
|
||||
self._verify_quantization()
|
||||
self._verify_cuda_graph()
|
||||
self._verify_bnb_config()
|
||||
@ -1008,6 +1008,23 @@ class ModelConfig:
|
||||
"when expert parallelism is enabled."
|
||||
)
|
||||
|
||||
def _try_verify_and_update_model_config(self):
|
||||
# Avoid running try_verify_and_update_config multiple times
|
||||
if getattr(self, "config_updated", False):
|
||||
return
|
||||
|
||||
architecture = self.architecture
|
||||
if architecture is None:
|
||||
return
|
||||
|
||||
from vllm.model_executor.models.config import (
|
||||
MODELS_CONFIG_MAP,
|
||||
)
|
||||
|
||||
cls = MODELS_CONFIG_MAP.get(architecture, None)
|
||||
if cls is not None:
|
||||
cls.verify_and_update_model_config(self)
|
||||
|
||||
def verify_dual_chunk_attention_config(
|
||||
self,
|
||||
load_config: LoadConfig,
|
||||
|
||||
@ -13,7 +13,7 @@ from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -21,20 +21,24 @@ logger = init_logger(__name__)
|
||||
class VerifyAndUpdateConfig:
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
raise NotImplementedError
|
||||
return
|
||||
|
||||
|
||||
class Gemma3TextModelConfig:
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
hf_config = vllm_config.model_config.hf_config
|
||||
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||
return
|
||||
|
||||
|
||||
class Gemma3TextModelConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||
hf_config = model_config.hf_config
|
||||
hf_config.is_causal = not hf_config.use_bidirectional_attention
|
||||
|
||||
|
||||
class GteNewModelConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
config = vllm_config.model_config.hf_config
|
||||
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||
config = model_config.hf_config
|
||||
|
||||
assert config.__class__.__name__ == "NewConfig"
|
||||
assert config.hidden_act == "gelu"
|
||||
@ -53,16 +57,15 @@ class GteNewModelConfig(VerifyAndUpdateConfig):
|
||||
|
||||
class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||
pooler_config = model_config.pooler_config
|
||||
if pooler_config.use_activation is None:
|
||||
pooler_config.use_activation = False
|
||||
|
||||
|
||||
class JinaRobertaModelConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
model_config = vllm_config.model_config
|
||||
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||
config = model_config.hf_config
|
||||
|
||||
if config.position_embedding_type == "rotary":
|
||||
@ -90,10 +93,10 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig):
|
||||
|
||||
class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||
from vllm.config.pooler import PoolingTypeStr
|
||||
|
||||
hf_config = vllm_config.model_config.hf_config
|
||||
hf_config = model_config.hf_config
|
||||
hf_config.is_causal = False
|
||||
|
||||
pooling_type_map: dict[str, PoolingTypeStr] = {
|
||||
@ -105,7 +108,7 @@ class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
|
||||
pooling_type = pooling_type_map.get(hf_config.pooling, None)
|
||||
if pooling_type is None:
|
||||
raise ValueError(f"pool_type {hf_config.pooling} not supported")
|
||||
vllm_config.model_config.pooler_config.pooling_type = pooling_type
|
||||
model_config.pooler_config.pooling_type = pooling_type
|
||||
|
||||
|
||||
class NomicBertModelConfig(VerifyAndUpdateConfig):
|
||||
@ -204,8 +207,8 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
|
||||
|
||||
class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||
pooler_config = model_config.pooler_config
|
||||
|
||||
if pooler_config.step_tag_id is None:
|
||||
pooler_config.step_tag_id = 151651
|
||||
@ -213,8 +216,8 @@ class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig):
|
||||
|
||||
class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||
pooler_config = model_config.pooler_config
|
||||
|
||||
if pooler_config.softmax is None:
|
||||
pooler_config.softmax = False
|
||||
@ -222,8 +225,8 @@ class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig):
|
||||
|
||||
class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
config = vllm_config.model_config.hf_config
|
||||
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||
config = model_config.hf_config
|
||||
|
||||
is_original_qwen3_reranker = getattr(
|
||||
config, "is_original_qwen3_reranker", False
|
||||
@ -237,23 +240,23 @@ class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
|
||||
"Try loading the original Qwen3 Reranker?, see: "
|
||||
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/offline_reranker.py"
|
||||
)
|
||||
vllm_config.model_config.hf_config.method = "from_2_way_softmax"
|
||||
model_config.hf_config.method = "from_2_way_softmax"
|
||||
|
||||
|
||||
class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
config = vllm_config.model_config.hf_config
|
||||
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||
config = model_config.hf_config
|
||||
config.num_labels = 1
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
pooler_config = model_config.pooler_config
|
||||
if pooler_config.logit_bias is None:
|
||||
pooler_config.logit_bias = 2.65
|
||||
|
||||
|
||||
class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
config = vllm_config.model_config.hf_config
|
||||
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||
config = model_config.hf_config
|
||||
|
||||
assert config.__class__.__name__ == "GteConfig"
|
||||
assert config.hidden_act == "gelu"
|
||||
|
||||
@ -64,7 +64,6 @@ from .interfaces import (
|
||||
SupportsLoRA,
|
||||
SupportsPP,
|
||||
)
|
||||
from .interfaces_base import attn_type
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
PPMissingLayer,
|
||||
@ -707,14 +706,12 @@ class LlamaForCausalLM(
|
||||
return name, loaded_weight
|
||||
|
||||
|
||||
@attn_type("encoder_only")
|
||||
class LlamaBidirectionalForSequenceClassification(as_seq_cls_model(LlamaForCausalLM)):
|
||||
# This class sets the correct attention type and pooling type
|
||||
# through LlamaBidirectionalConfig.
|
||||
pass
|
||||
|
||||
|
||||
@attn_type("encoder_only")
|
||||
class LlamaBidirectionalModel(as_embedding_model(LlamaForCausalLM)):
|
||||
# This class sets the correct attention type and pooling type
|
||||
# through LlamaBidirectionalConfig.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user