mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-23 18:16:01 +08:00
[Model] Supplement to PR 24862: Pass param prefix to LLMHead (#25805)
Signed-off-by: whx-sjtu <2952154980@qq.com> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
c6f384dafd
commit
fac9b430ec
@ -28,13 +28,15 @@ class SharedHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
prefix: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "head"))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
return self.norm(hidden_states)
|
||||
@ -64,7 +66,9 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
|
||||
device="cuda")
|
||||
else:
|
||||
topk_indices_buffer = None
|
||||
self.shared_head = SharedHead(config=config, quant_config=quant_config)
|
||||
self.shared_head = SharedHead(config=config,
|
||||
prefix=prefix,
|
||||
quant_config=quant_config)
|
||||
self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix,
|
||||
topk_indices_buffer)
|
||||
|
||||
|
||||
@ -50,13 +50,15 @@ class SharedHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
prefix: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "head"))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
return self.norm(hidden_states)
|
||||
@ -77,7 +79,9 @@ class Glm4MoeMultiTokenPredictorLayer(nn.Module):
|
||||
self.eh_proj = nn.Linear(config.hidden_size * 2,
|
||||
config.hidden_size,
|
||||
bias=False)
|
||||
self.shared_head = SharedHead(config=config, quant_config=quant_config)
|
||||
self.shared_head = SharedHead(config=config,
|
||||
prefix=prefix,
|
||||
quant_config=quant_config)
|
||||
self.mtp_block = Glm4MoeDecoderLayer(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
|
||||
@ -296,6 +296,7 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "embed_out"),
|
||||
)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.embed_out.weight = self.gpt_neox.embed_in.weight
|
||||
|
||||
@ -140,6 +140,7 @@ class LongCatFlashMTP(nn.Module, SupportsPP):
|
||||
self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
quant_config=self.quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(self.config.vocab_size)
|
||||
|
||||
|
||||
@ -82,7 +82,8 @@ class Medusa(nn.Module):
|
||||
config.hidden_size,
|
||||
org_num_embeddings=self.truncated_vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||
) for _ in range(self.config.num_heads)
|
||||
prefix=maybe_prefix(prefix, f"lm_heads.{i}"),
|
||||
) for i in range(self.config.num_heads)
|
||||
])
|
||||
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
|
||||
@ -13,6 +13,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from .utils import maybe_prefix
|
||||
|
||||
SQRT2 = 2**0.5
|
||||
|
||||
|
||||
@ -97,8 +99,13 @@ class MLPSpeculator(nn.Module):
|
||||
self.proj = nn.ModuleList([proj_first] + [proj_tied] *
|
||||
(self.max_speculative_tokens - 1))
|
||||
|
||||
head = ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
|
||||
self.head = nn.ModuleList([head] * self.max_speculative_tokens)
|
||||
self.head = nn.ModuleList([
|
||||
ParallelLMHead(self.vocab_size,
|
||||
self.inner_dim,
|
||||
bias=False,
|
||||
prefix=maybe_prefix(prefix, f"head.{i}"))
|
||||
for i in range(self.max_speculative_tokens)
|
||||
])
|
||||
|
||||
ln = MLPSpeculatorLayerNorm(self.inner_dim,
|
||||
elementwise_scale_and_shift=True)
|
||||
@ -120,8 +127,11 @@ class MLPSpeculator(nn.Module):
|
||||
])
|
||||
|
||||
self.head = nn.ModuleList([
|
||||
ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
|
||||
for _ in range(self.max_speculative_tokens)
|
||||
ParallelLMHead(self.vocab_size,
|
||||
self.inner_dim,
|
||||
bias=False,
|
||||
prefix=maybe_prefix(prefix, f"head.{i}"))
|
||||
for i in range(self.max_speculative_tokens)
|
||||
])
|
||||
self.ln = nn.ModuleList([
|
||||
MLPSpeculatorLayerNorm(self.inner_dim,
|
||||
|
||||
@ -296,7 +296,8 @@ class Qwen3MoeLLMForCausalLM(Qwen3MoeForCausalLM):
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.lm_head = ParallelLMHead(self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
quant_config=self.quant_config)
|
||||
quant_config=self.quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"))
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
self.logits_processor = LogitsProcessor(self.config.vocab_size)
|
||||
|
||||
@ -45,7 +45,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
|
||||
SupportsTranscription)
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
|
||||
make_layers)
|
||||
make_layers, maybe_prefix)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -885,7 +885,8 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
self.proj_out = ParallelLMHead(config.vocab_size,
|
||||
config.d_model,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "proj_out"))
|
||||
self.proj_out = self.proj_out.tie_weights(
|
||||
self.model.decoder.embed_tokens)
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user