[Model] Supplement to PR 24862: Pass param prefix to LLMHead (#25805)

Signed-off-by: whx-sjtu <2952154980@qq.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Authored by whx on 2025-10-03 21:34:53 +08:00, committed by yewentao256
parent c6f384dafd
commit fac9b430ec
8 changed files with 35 additions and 12 deletions
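
The change threads the module prefix from each parent model down into its ParallelLMHead, so the head is registered under its full dotted path instead of an empty name. Per-layer settings (for example quantization overrides) are typically looked up by that path, which is why every constructor below now forwards a prefix. The helper used throughout the diff, maybe_prefix, joins a parent prefix with a child name; the sketch below shows the assumed behavior (illustrative only, not copied from vLLM's source):

    def maybe_prefix(prefix: str, name: str) -> str:
        # Join a parent module prefix with a child name; an empty parent
        # prefix yields just the child name (assumed behavior).
        return f"{prefix}.{name}" if prefix else name

    assert maybe_prefix("shared_head", "head") == "shared_head.head"
    assert maybe_prefix("", "lm_head") == "lm_head"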

View File

@@ -28,13 +28,15 @@ class SharedHead(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        prefix: str,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.head = ParallelLMHead(config.vocab_size,
                                    config.hidden_size,
-                                   quant_config=quant_config)
+                                   quant_config=quant_config,
+                                   prefix=maybe_prefix(prefix, "head"))
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return self.norm(hidden_states)
@@ -64,7 +66,9 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
                 device="cuda")
         else:
             topk_indices_buffer = None
 
-        self.shared_head = SharedHead(config=config, quant_config=quant_config)
+        self.shared_head = SharedHead(config=config,
+                                      prefix=prefix,
+                                      quant_config=quant_config)
         self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix,
                                                 topk_indices_buffer)

View File

@@ -50,13 +50,15 @@ class SharedHead(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        prefix: str,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.head = ParallelLMHead(config.vocab_size,
                                    config.hidden_size,
-                                   quant_config=quant_config)
+                                   quant_config=quant_config,
+                                   prefix=maybe_prefix(prefix, "head"))
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return self.norm(hidden_states)
@@ -77,7 +79,9 @@ class Glm4MoeMultiTokenPredictorLayer(nn.Module):
         self.eh_proj = nn.Linear(config.hidden_size * 2,
                                  config.hidden_size,
                                  bias=False)
-        self.shared_head = SharedHead(config=config, quant_config=quant_config)
+        self.shared_head = SharedHead(config=config,
+                                      prefix=prefix,
+                                      quant_config=quant_config)
         self.mtp_block = Glm4MoeDecoderLayer(config=config,
                                              cache_config=cache_config,
                                              quant_config=quant_config,

View File

@@ -296,6 +296,7 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
             config.vocab_size,
             config.hidden_size,
             quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "embed_out"),
         )
         if self.config.tie_word_embeddings:
             self.embed_out.weight = self.gpt_neox.embed_in.weight

View File

@@ -140,6 +140,7 @@ class LongCatFlashMTP(nn.Module, SupportsPP):
             self.config.vocab_size,
             self.config.hidden_size,
             quant_config=self.quant_config,
+            prefix=maybe_prefix(prefix, "lm_head"),
         )
 
         self.logits_processor = LogitsProcessor(self.config.vocab_size)

View File

@@ -82,7 +82,8 @@ class Medusa(nn.Module):
                 config.hidden_size,
                 org_num_embeddings=self.truncated_vocab_size,
                 padding_size=DEFAULT_VOCAB_PADDING_SIZE,
-            ) for _ in range(self.config.num_heads)
+                prefix=maybe_prefix(prefix, f"lm_heads.{i}"),
+            ) for i in range(self.config.num_heads)
         ])
 
         logit_scale = getattr(config, "logit_scale", 1.0)
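
The Medusa hunk also switches the comprehension variable from _ to i so that each draft head gets its own index-specific prefix, which lines up with how nn.ModuleList names its children. A small sketch of that naming convention (plain PyTorch, illustrative only):

    import torch.nn as nn

    # Children of an nn.ModuleList are named by their index, so a head at
    # position i naturally lives under "<parent>.lm_heads.<i>".
    lm_heads = nn.ModuleList([nn.Linear(8, 8) for i in range(2)])
    print([name for name, _ in lm_heads.named_parameters()])
    # -> ['0.weight', '0.bias', '1.weight', '1.bias']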

View File

@@ -13,6 +13,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from .utils import maybe_prefix
+
 SQRT2 = 2**0.5
@@ -97,8 +99,13 @@ class MLPSpeculator(nn.Module):
             self.proj = nn.ModuleList([proj_first] + [proj_tied] *
                                       (self.max_speculative_tokens - 1))
 
-            head = ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
-            self.head = nn.ModuleList([head] * self.max_speculative_tokens)
+            self.head = nn.ModuleList([
+                ParallelLMHead(self.vocab_size,
+                               self.inner_dim,
+                               bias=False,
+                               prefix=maybe_prefix(prefix, f"head.{i}"))
+                for i in range(self.max_speculative_tokens)
+            ])
 
             ln = MLPSpeculatorLayerNorm(self.inner_dim,
                                         elementwise_scale_and_shift=True)
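
Note that this hunk also replaces a single head replicated via [head] * n with a comprehension that builds one ParallelLMHead per speculative token. Repeating a module object shares its parameters, while a comprehension creates independent modules, which is what makes a distinct head.{i} prefix per entry meaningful; whether the weights end up tied again at load time is handled outside this hunk. A minimal illustration of the difference (plain PyTorch, for clarity only):

    import torch.nn as nn

    shared = nn.ModuleList([nn.Linear(8, 8)] * 3)                      # one module, three references
    independent = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])   # three separate modules

    print(shared[0].weight is shared[1].weight)             # True: weights are shared
    print(independent[0].weight is independent[1].weight)   # False: each entry has its own weights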
@@ -120,8 +127,11 @@ class MLPSpeculator(nn.Module):
             ])
 
             self.head = nn.ModuleList([
-                ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
-                for _ in range(self.max_speculative_tokens)
+                ParallelLMHead(self.vocab_size,
+                               self.inner_dim,
+                               bias=False,
+                               prefix=maybe_prefix(prefix, f"head.{i}"))
+                for i in range(self.max_speculative_tokens)
             ])
             self.ln = nn.ModuleList([
                 MLPSpeculatorLayerNorm(self.inner_dim,

View File

@@ -296,7 +296,8 @@ class Qwen3MoeLLMForCausalLM(Qwen3MoeForCausalLM):
                                    prefix=maybe_prefix(prefix, "model"))
         self.lm_head = ParallelLMHead(self.config.vocab_size,
                                       self.config.hidden_size,
-                                      quant_config=self.quant_config)
+                                      quant_config=self.quant_config,
+                                      prefix=maybe_prefix(prefix, "lm_head"))
         if self.config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(self.config.vocab_size)

View File

@@ -45,7 +45,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
                          SupportsTranscription)
 from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
-                    make_layers)
+                    make_layers, maybe_prefix)
 
 logger = init_logger(__name__)
@@ -885,7 +885,8 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
         self.unpadded_vocab_size = config.vocab_size
         self.proj_out = ParallelLMHead(config.vocab_size,
                                        config.d_model,
-                                       quant_config=quant_config)
+                                       quant_config=quant_config,
+                                       prefix=maybe_prefix(prefix, "proj_out"))
         self.proj_out = self.proj_out.tie_weights(
             self.model.decoder.embed_tokens)
         logit_scale = getattr(config, "logit_scale", 1.0)