From fac9b430ec2a3070f81f1e79ff34ace57501f9b4 Mon Sep 17 00:00:00 2001
From: whx <56632993+whx-sjtu@users.noreply.github.com>
Date: Fri, 3 Oct 2025 21:34:53 +0800
Subject: [PATCH] [Model] Supplement to PR 24862: Pass param prefix to LLMHead
 (#25805)

Signed-off-by: whx-sjtu <2952154980@qq.com>
Signed-off-by: yewentao256
---
 vllm/model_executor/models/deepseek_mtp.py     |  8 ++++++--
 vllm/model_executor/models/glm4_moe_mtp.py     |  8 ++++++--
 vllm/model_executor/models/gpt_neox.py         |  1 +
 .../model_executor/models/longcat_flash_mtp.py |  1 +
 vllm/model_executor/models/medusa.py           |  3 ++-
 vllm/model_executor/models/mlp_speculator.py   | 18 ++++++++++++++----
 vllm/model_executor/models/qwen3_vl_moe.py     |  3 ++-
 vllm/model_executor/models/whisper.py          |  5 +++--
 8 files changed, 35 insertions(+), 12 deletions(-)

diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py
index 788e561ac394d..02a25ab762e59 100644
--- a/vllm/model_executor/models/deepseek_mtp.py
+++ b/vllm/model_executor/models/deepseek_mtp.py
@@ -28,13 +28,15 @@ class SharedHead(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        prefix: str,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.head = ParallelLMHead(config.vocab_size,
                                    config.hidden_size,
-                                   quant_config=quant_config)
+                                   quant_config=quant_config,
+                                   prefix=maybe_prefix(prefix, "head"))
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return self.norm(hidden_states)
@@ -64,7 +66,9 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
                                               device="cuda")
         else:
             topk_indices_buffer = None
-        self.shared_head = SharedHead(config=config, quant_config=quant_config)
+        self.shared_head = SharedHead(config=config,
+                                      prefix=prefix,
+                                      quant_config=quant_config)
         self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix,
                                                 topk_indices_buffer)
 
diff --git a/vllm/model_executor/models/glm4_moe_mtp.py b/vllm/model_executor/models/glm4_moe_mtp.py
index 826d541e571bd..57b698e239eca 100644
--- a/vllm/model_executor/models/glm4_moe_mtp.py
+++ b/vllm/model_executor/models/glm4_moe_mtp.py
@@ -50,13 +50,15 @@ class SharedHead(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        prefix: str,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.head = ParallelLMHead(config.vocab_size,
                                    config.hidden_size,
-                                   quant_config=quant_config)
+                                   quant_config=quant_config,
+                                   prefix=maybe_prefix(prefix, "head"))
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return self.norm(hidden_states)
@@ -77,7 +79,9 @@ class Glm4MoeMultiTokenPredictorLayer(nn.Module):
         self.eh_proj = nn.Linear(config.hidden_size * 2,
                                  config.hidden_size,
                                  bias=False)
-        self.shared_head = SharedHead(config=config, quant_config=quant_config)
+        self.shared_head = SharedHead(config=config,
+                                      prefix=prefix,
+                                      quant_config=quant_config)
         self.mtp_block = Glm4MoeDecoderLayer(config=config,
                                              cache_config=cache_config,
                                              quant_config=quant_config,
diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py
index 7570aefb6e96e..45519a94d854c 100644
--- a/vllm/model_executor/models/gpt_neox.py
+++ b/vllm/model_executor/models/gpt_neox.py
@@ -296,6 +296,7 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
             config.vocab_size,
             config.hidden_size,
             quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "embed_out"),
         )
         if self.config.tie_word_embeddings:
             self.embed_out.weight = self.gpt_neox.embed_in.weight
diff --git a/vllm/model_executor/models/longcat_flash_mtp.py b/vllm/model_executor/models/longcat_flash_mtp.py
index eebc2ee155979..e288658a7ebf3 100644
--- a/vllm/model_executor/models/longcat_flash_mtp.py
+++ b/vllm/model_executor/models/longcat_flash_mtp.py
@@ -140,6 +140,7 @@ class LongCatFlashMTP(nn.Module, SupportsPP):
             self.config.vocab_size,
             self.config.hidden_size,
             quant_config=self.quant_config,
+            prefix=maybe_prefix(prefix, "lm_head"),
         )
 
         self.logits_processor = LogitsProcessor(self.config.vocab_size)
diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py
index 0ae59dc8dfc23..f083c2cb0380a 100644
--- a/vllm/model_executor/models/medusa.py
+++ b/vllm/model_executor/models/medusa.py
@@ -82,7 +82,8 @@ class Medusa(nn.Module):
                 config.hidden_size,
                 org_num_embeddings=self.truncated_vocab_size,
                 padding_size=DEFAULT_VOCAB_PADDING_SIZE,
-            ) for _ in range(self.config.num_heads)
+                prefix=maybe_prefix(prefix, f"lm_heads.{i}"),
+            ) for i in range(self.config.num_heads)
         ])
 
         logit_scale = getattr(config, "logit_scale", 1.0)
diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py
index d057eb49a62d1..0f375134ef00f 100644
--- a/vllm/model_executor/models/mlp_speculator.py
+++ b/vllm/model_executor/models/mlp_speculator.py
@@ -13,6 +13,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from .utils import maybe_prefix
+
 SQRT2 = 2**0.5
 
 
@@ -97,8 +99,13 @@ class MLPSpeculator(nn.Module):
             self.proj = nn.ModuleList([proj_first] + [proj_tied] *
                                       (self.max_speculative_tokens - 1))
 
-            head = ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
-            self.head = nn.ModuleList([head] * self.max_speculative_tokens)
+            self.head = nn.ModuleList([
+                ParallelLMHead(self.vocab_size,
+                               self.inner_dim,
+                               bias=False,
+                               prefix=maybe_prefix(prefix, f"head.{i}"))
+                for i in range(self.max_speculative_tokens)
+            ])
 
             ln = MLPSpeculatorLayerNorm(self.inner_dim,
                                         elementwise_scale_and_shift=True)
@@ -120,8 +127,11 @@ class MLPSpeculator(nn.Module):
             ])
 
             self.head = nn.ModuleList([
-                ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
-                for _ in range(self.max_speculative_tokens)
+                ParallelLMHead(self.vocab_size,
+                               self.inner_dim,
+                               bias=False,
+                               prefix=maybe_prefix(prefix, f"head.{i}"))
+                for i in range(self.max_speculative_tokens)
             ])
             self.ln = nn.ModuleList([
                 MLPSpeculatorLayerNorm(self.inner_dim,
diff --git a/vllm/model_executor/models/qwen3_vl_moe.py b/vllm/model_executor/models/qwen3_vl_moe.py
index 1ed053eb2e96c..bd4aae7404c61 100644
--- a/vllm/model_executor/models/qwen3_vl_moe.py
+++ b/vllm/model_executor/models/qwen3_vl_moe.py
@@ -296,7 +296,8 @@ class Qwen3MoeLLMForCausalLM(Qwen3MoeForCausalLM):
                                       prefix=maybe_prefix(prefix, "model"))
         self.lm_head = ParallelLMHead(self.config.vocab_size,
                                       self.config.hidden_size,
-                                      quant_config=self.quant_config)
+                                      quant_config=self.quant_config,
+                                      prefix=maybe_prefix(prefix, "lm_head"))
         if self.config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(self.config.vocab_size)
diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py
index d349d91dfd760..84686b8b19411 100644
--- a/vllm/model_executor/models/whisper.py
+++ b/vllm/model_executor/models/whisper.py
@@ -45,7 +45,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
                          SupportsTranscription)
 from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
-                    make_layers)
+                    make_layers, maybe_prefix)
 
 logger = init_logger(__name__)
 
@@ -885,7 +885,8 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
         self.unpadded_vocab_size = config.vocab_size
         self.proj_out = ParallelLMHead(config.vocab_size,
                                        config.d_model,
-                                       quant_config=quant_config)
+                                       quant_config=quant_config,
+                                       prefix=maybe_prefix(prefix, "proj_out"))
         self.proj_out = self.proj_out.tie_weights(
             self.model.decoder.embed_tokens)
         logit_scale = getattr(config, "logit_scale", 1.0)
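
The patch threads each module's parameter prefix down to its ParallelLMHead so the head is constructed under a fully qualified name (for example "lm_head" nested under the parent module's prefix), which layer-level logic such as quantization configs can key off. Below is a minimal illustrative sketch of that pattern, assuming maybe_prefix (defined in vllm/model_executor/models/utils.py) simply joins a parent prefix and a child name with a dot; ToyLMHead and ToyModel are hypothetical stand-ins, not vLLM classes.

    # Illustrative sketch only; ToyLMHead and ToyModel are hypothetical.
    def maybe_prefix(prefix: str, name: str) -> str:
        # Join a parent prefix and a child module name with a dot,
        # or return the bare name when there is no parent prefix.
        return name if not prefix else f"{prefix}.{name}"

    class ToyLMHead:
        def __init__(self, prefix: str = "") -> None:
            # Per-layer logic (e.g. quantization) can key off this
            # fully qualified name, which is why the prefix is passed down.
            self.prefix = prefix

    class ToyModel:
        def __init__(self, prefix: str = "") -> None:
            self.lm_head = ToyLMHead(prefix=maybe_prefix(prefix, "lm_head"))

    assert ToyModel(prefix="model").lm_head.prefix == "model.lm_head"
    assert ToyModel().lm_head.prefix == "lm_head"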