mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-21 15:37:01 +08:00
[Model] Introduce verify_and_update_model_config for VerifyAndUpdateConfig. (#31131)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io> Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
parent
b41aeb3468
commit
bd89ce16d2
@ -595,7 +595,7 @@ class ModelConfig:
|
|||||||
|
|
||||||
# Avoid running try_verify_and_update_config multiple times
|
# Avoid running try_verify_and_update_config multiple times
|
||||||
self.config_updated = False
|
self.config_updated = False
|
||||||
|
self._try_verify_and_update_model_config()
|
||||||
self._verify_quantization()
|
self._verify_quantization()
|
||||||
self._verify_cuda_graph()
|
self._verify_cuda_graph()
|
||||||
self._verify_bnb_config()
|
self._verify_bnb_config()
|
||||||
@ -1008,6 +1008,23 @@ class ModelConfig:
|
|||||||
"when expert parallelism is enabled."
|
"when expert parallelism is enabled."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _try_verify_and_update_model_config(self):
|
||||||
|
# Avoid running try_verify_and_update_config multiple times
|
||||||
|
if getattr(self, "config_updated", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
architecture = self.architecture
|
||||||
|
if architecture is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
from vllm.model_executor.models.config import (
|
||||||
|
MODELS_CONFIG_MAP,
|
||||||
|
)
|
||||||
|
|
||||||
|
cls = MODELS_CONFIG_MAP.get(architecture, None)
|
||||||
|
if cls is not None:
|
||||||
|
cls.verify_and_update_model_config(self)
|
||||||
|
|
||||||
def verify_dual_chunk_attention_config(
|
def verify_dual_chunk_attention_config(
|
||||||
self,
|
self,
|
||||||
load_config: LoadConfig,
|
load_config: LoadConfig,
|
||||||
|
|||||||
@ -13,7 +13,7 @@ from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
|||||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec
|
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import ModelConfig, VllmConfig
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -21,20 +21,24 @@ logger = init_logger(__name__)
|
|||||||
class VerifyAndUpdateConfig:
|
class VerifyAndUpdateConfig:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||||
raise NotImplementedError
|
return
|
||||||
|
|
||||||
|
|
||||||
class Gemma3TextModelConfig:
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||||
hf_config = vllm_config.model_config.hf_config
|
return
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3TextModelConfig(VerifyAndUpdateConfig):
|
||||||
|
@staticmethod
|
||||||
|
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||||
|
hf_config = model_config.hf_config
|
||||||
hf_config.is_causal = not hf_config.use_bidirectional_attention
|
hf_config.is_causal = not hf_config.use_bidirectional_attention
|
||||||
|
|
||||||
|
|
||||||
class GteNewModelConfig(VerifyAndUpdateConfig):
|
class GteNewModelConfig(VerifyAndUpdateConfig):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||||
config = vllm_config.model_config.hf_config
|
config = model_config.hf_config
|
||||||
|
|
||||||
assert config.__class__.__name__ == "NewConfig"
|
assert config.__class__.__name__ == "NewConfig"
|
||||||
assert config.hidden_act == "gelu"
|
assert config.hidden_act == "gelu"
|
||||||
@ -53,16 +57,15 @@ class GteNewModelConfig(VerifyAndUpdateConfig):
|
|||||||
|
|
||||||
class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
|
class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||||
pooler_config = vllm_config.model_config.pooler_config
|
pooler_config = model_config.pooler_config
|
||||||
if pooler_config.use_activation is None:
|
if pooler_config.use_activation is None:
|
||||||
pooler_config.use_activation = False
|
pooler_config.use_activation = False
|
||||||
|
|
||||||
|
|
||||||
class JinaRobertaModelConfig(VerifyAndUpdateConfig):
|
class JinaRobertaModelConfig(VerifyAndUpdateConfig):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||||
model_config = vllm_config.model_config
|
|
||||||
config = model_config.hf_config
|
config = model_config.hf_config
|
||||||
|
|
||||||
if config.position_embedding_type == "rotary":
|
if config.position_embedding_type == "rotary":
|
||||||
@ -90,10 +93,10 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig):
|
|||||||
|
|
||||||
class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
|
class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||||
from vllm.config.pooler import PoolingTypeStr
|
from vllm.config.pooler import PoolingTypeStr
|
||||||
|
|
||||||
hf_config = vllm_config.model_config.hf_config
|
hf_config = model_config.hf_config
|
||||||
hf_config.is_causal = False
|
hf_config.is_causal = False
|
||||||
|
|
||||||
pooling_type_map: dict[str, PoolingTypeStr] = {
|
pooling_type_map: dict[str, PoolingTypeStr] = {
|
||||||
@ -105,7 +108,7 @@ class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
|
|||||||
pooling_type = pooling_type_map.get(hf_config.pooling, None)
|
pooling_type = pooling_type_map.get(hf_config.pooling, None)
|
||||||
if pooling_type is None:
|
if pooling_type is None:
|
||||||
raise ValueError(f"pool_type {hf_config.pooling} not supported")
|
raise ValueError(f"pool_type {hf_config.pooling} not supported")
|
||||||
vllm_config.model_config.pooler_config.pooling_type = pooling_type
|
model_config.pooler_config.pooling_type = pooling_type
|
||||||
|
|
||||||
|
|
||||||
class NomicBertModelConfig(VerifyAndUpdateConfig):
|
class NomicBertModelConfig(VerifyAndUpdateConfig):
|
||||||
@ -204,8 +207,8 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
|
|||||||
|
|
||||||
class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig):
|
class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||||
pooler_config = vllm_config.model_config.pooler_config
|
pooler_config = model_config.pooler_config
|
||||||
|
|
||||||
if pooler_config.step_tag_id is None:
|
if pooler_config.step_tag_id is None:
|
||||||
pooler_config.step_tag_id = 151651
|
pooler_config.step_tag_id = 151651
|
||||||
@ -213,8 +216,8 @@ class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig):
|
|||||||
|
|
||||||
class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig):
|
class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||||
pooler_config = vllm_config.model_config.pooler_config
|
pooler_config = model_config.pooler_config
|
||||||
|
|
||||||
if pooler_config.softmax is None:
|
if pooler_config.softmax is None:
|
||||||
pooler_config.softmax = False
|
pooler_config.softmax = False
|
||||||
@ -222,8 +225,8 @@ class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig):
|
|||||||
|
|
||||||
class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
|
class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||||
config = vllm_config.model_config.hf_config
|
config = model_config.hf_config
|
||||||
|
|
||||||
is_original_qwen3_reranker = getattr(
|
is_original_qwen3_reranker = getattr(
|
||||||
config, "is_original_qwen3_reranker", False
|
config, "is_original_qwen3_reranker", False
|
||||||
@ -237,23 +240,23 @@ class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
|
|||||||
"Try loading the original Qwen3 Reranker?, see: "
|
"Try loading the original Qwen3 Reranker?, see: "
|
||||||
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/offline_reranker.py"
|
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/offline_reranker.py"
|
||||||
)
|
)
|
||||||
vllm_config.model_config.hf_config.method = "from_2_way_softmax"
|
model_config.hf_config.method = "from_2_way_softmax"
|
||||||
|
|
||||||
|
|
||||||
class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig):
|
class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||||
config = vllm_config.model_config.hf_config
|
config = model_config.hf_config
|
||||||
config.num_labels = 1
|
config.num_labels = 1
|
||||||
pooler_config = vllm_config.model_config.pooler_config
|
pooler_config = model_config.pooler_config
|
||||||
if pooler_config.logit_bias is None:
|
if pooler_config.logit_bias is None:
|
||||||
pooler_config.logit_bias = 2.65
|
pooler_config.logit_bias = 2.65
|
||||||
|
|
||||||
|
|
||||||
class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
|
class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
|
||||||
config = vllm_config.model_config.hf_config
|
config = model_config.hf_config
|
||||||
|
|
||||||
assert config.__class__.__name__ == "GteConfig"
|
assert config.__class__.__name__ == "GteConfig"
|
||||||
assert config.hidden_act == "gelu"
|
assert config.hidden_act == "gelu"
|
||||||
|
|||||||
@ -64,7 +64,6 @@ from .interfaces import (
|
|||||||
SupportsLoRA,
|
SupportsLoRA,
|
||||||
SupportsPP,
|
SupportsPP,
|
||||||
)
|
)
|
||||||
from .interfaces_base import attn_type
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
AutoWeightsLoader,
|
AutoWeightsLoader,
|
||||||
PPMissingLayer,
|
PPMissingLayer,
|
||||||
@ -707,14 +706,12 @@ class LlamaForCausalLM(
|
|||||||
return name, loaded_weight
|
return name, loaded_weight
|
||||||
|
|
||||||
|
|
||||||
@attn_type("encoder_only")
|
|
||||||
class LlamaBidirectionalForSequenceClassification(as_seq_cls_model(LlamaForCausalLM)):
|
class LlamaBidirectionalForSequenceClassification(as_seq_cls_model(LlamaForCausalLM)):
|
||||||
# This class sets the correct attention type and pooling type
|
# This class sets the correct attention type and pooling type
|
||||||
# through LlamaBidirectionalConfig.
|
# through LlamaBidirectionalConfig.
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@attn_type("encoder_only")
|
|
||||||
class LlamaBidirectionalModel(as_embedding_model(LlamaForCausalLM)):
|
class LlamaBidirectionalModel(as_embedding_model(LlamaForCausalLM)):
|
||||||
# This class sets the correct attention type and pooling type
|
# This class sets the correct attention type and pooling type
|
||||||
# through LlamaBidirectionalConfig.
|
# through LlamaBidirectionalConfig.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user