[Model] Supplement to PR 24862: Pass param prefix to LLMHead (#25805)
Signed-off-by: whx-sjtu <2952154980@qq.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent c6f384dafd
commit fac9b430ec
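
Context for the changes below: every call site now routes its dotted module name through maybe_prefix before constructing a ParallelLMHead. A minimal sketch of that helper, assuming it behaves like the one imported from .utils later in this commit:

# Sketch of the maybe_prefix helper assumed throughout this diff; the real
# one lives in vllm/model_executor/models/utils.py.
def maybe_prefix(prefix: str, name: str) -> str:
    # Join a dotted parent prefix and a child name; an empty prefix
    # yields just the bare name.
    return name if not prefix else f"{prefix}.{name}"

assert maybe_prefix("", "head") == "head"
assert maybe_prefix("model.layers.0", "head") == "model.layers.0.head"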
@@ -28,13 +28,15 @@ class SharedHead(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        prefix: str,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.head = ParallelLMHead(config.vocab_size,
                                    config.hidden_size,
-                                   quant_config=quant_config)
+                                   quant_config=quant_config,
+                                   prefix=maybe_prefix(prefix, "head"))
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return self.norm(hidden_states)
@@ -64,7 +66,9 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
                                               device="cuda")
         else:
             topk_indices_buffer = None
-        self.shared_head = SharedHead(config=config, quant_config=quant_config)
+        self.shared_head = SharedHead(config=config,
+                                      prefix=prefix,
+                                      quant_config=quant_config)
         self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix,
                                                 topk_indices_buffer)
 
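
With SharedHead now accepting prefix, its LM head registers under the parent layer's dotted name instead of an empty one. A hypothetical illustration (the prefix value below is made up, not taken from the diff):

# Hypothetical prefix; the resulting parameter name is what per-layer
# quantization rules and checkpoint loaders match against.
prefix = "model.layers.61.shared_head"
head_name = f"{prefix}.head"     # what maybe_prefix(prefix, "head") yields
print(f"{head_name}.weight")     # model.layers.61.shared_head.head.weight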
@@ -50,13 +50,15 @@ class SharedHead(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        prefix: str,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
        super().__init__()
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.head = ParallelLMHead(config.vocab_size,
                                    config.hidden_size,
-                                   quant_config=quant_config)
+                                   quant_config=quant_config,
+                                   prefix=maybe_prefix(prefix, "head"))
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return self.norm(hidden_states)
@@ -77,7 +79,9 @@ class Glm4MoeMultiTokenPredictorLayer(nn.Module):
         self.eh_proj = nn.Linear(config.hidden_size * 2,
                                  config.hidden_size,
                                  bias=False)
-        self.shared_head = SharedHead(config=config, quant_config=quant_config)
+        self.shared_head = SharedHead(config=config,
+                                      prefix=prefix,
+                                      quant_config=quant_config)
         self.mtp_block = Glm4MoeDecoderLayer(config=config,
                                              cache_config=cache_config,
                                              quant_config=quant_config,
@@ -296,6 +296,7 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
             config.vocab_size,
             config.hidden_size,
             quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "embed_out"),
         )
         if self.config.tie_word_embeddings:
             self.embed_out.weight = self.gpt_neox.embed_in.weight
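
The prefix is what lets a quantization config make per-layer decisions: vLLM configs receive it through get_quant_method(layer, prefix). A toy sketch of that pattern, not vLLM's actual config class:

from typing import Optional

class ToyQuantConfig:
    # Illustrative stand-in for a QuantizationConfig subclass.
    def __init__(self, ignored_layers: list[str]) -> None:
        self.ignored_layers = ignored_layers

    def get_quant_method(self, layer: object, prefix: str) -> Optional[str]:
        # Before this PR, these LM heads were built with an empty prefix,
        # so name-based exclusions like this could not match them.
        if prefix in self.ignored_layers:
            return None        # leave the layer unquantized
        return "quantized"     # apply the configured scheme

cfg = ToyQuantConfig(ignored_layers=["embed_out"])
assert cfg.get_quant_method(None, "embed_out") is None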
@@ -140,6 +140,7 @@ class LongCatFlashMTP(nn.Module, SupportsPP):
             self.config.vocab_size,
             self.config.hidden_size,
             quant_config=self.quant_config,
+            prefix=maybe_prefix(prefix, "lm_head"),
         )
         self.logits_processor = LogitsProcessor(self.config.vocab_size)
 
@@ -82,7 +82,8 @@ class Medusa(nn.Module):
                 config.hidden_size,
                 org_num_embeddings=self.truncated_vocab_size,
                 padding_size=DEFAULT_VOCAB_PADDING_SIZE,
-            ) for _ in range(self.config.num_heads)
+                prefix=maybe_prefix(prefix, f"lm_heads.{i}"),
+            ) for i in range(self.config.num_heads)
         ])
 
         logit_scale = getattr(config, "logit_scale", 1.0)
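
Switching the comprehension variable from _ to i gives every Medusa head an indexed prefix matching its position in the ModuleList. An illustrative check:

# num_heads is a made-up value standing in for config.num_heads.
num_heads = 3
names = [f"lm_heads.{i}" for i in range(num_heads)]
assert names == ["lm_heads.0", "lm_heads.1", "lm_heads.2"]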
@@ -13,6 +13,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from .utils import maybe_prefix
+
 SQRT2 = 2**0.5
 
 
@@ -97,8 +99,13 @@ class MLPSpeculator(nn.Module):
             self.proj = nn.ModuleList([proj_first] + [proj_tied] *
                                       (self.max_speculative_tokens - 1))
 
-            head = ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
-            self.head = nn.ModuleList([head] * self.max_speculative_tokens)
+            self.head = nn.ModuleList([
+                ParallelLMHead(self.vocab_size,
+                               self.inner_dim,
+                               bias=False,
+                               prefix=maybe_prefix(prefix, f"head.{i}"))
+                for i in range(self.max_speculative_tokens)
+            ])
 
             ln = MLPSpeculatorLayerNorm(self.inner_dim,
                                         elementwise_scale_and_shift=True)
@@ -120,8 +127,11 @@ class MLPSpeculator(nn.Module):
             ])
 
             self.head = nn.ModuleList([
-                ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
-                for _ in range(self.max_speculative_tokens)
+                ParallelLMHead(self.vocab_size,
+                               self.inner_dim,
+                               bias=False,
+                               prefix=maybe_prefix(prefix, f"head.{i}"))
+                for i in range(self.max_speculative_tokens)
             ])
             self.ln = nn.ModuleList([
                 MLPSpeculatorLayerNorm(self.inner_dim,
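
One structural note on the tied-weights branch above: [head] * n stores n references to a single module, while the replacement comprehension constructs n distinct, individually prefixed modules. A small sketch of that Python-level difference, with plain nn.Linear standing in for ParallelLMHead:

import torch.nn as nn

aliased = nn.ModuleList([nn.Linear(4, 4)] * 3)                   # old pattern
distinct = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])    # new pattern

assert aliased[0] is aliased[2]          # one object under three names
assert distinct[0] is not distinct[2]    # three separate objects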
@@ -296,7 +296,8 @@ class Qwen3MoeLLMForCausalLM(Qwen3MoeForCausalLM):
                                  prefix=maybe_prefix(prefix, "model"))
         self.lm_head = ParallelLMHead(self.config.vocab_size,
                                       self.config.hidden_size,
-                                      quant_config=self.quant_config)
+                                      quant_config=self.quant_config,
+                                      prefix=maybe_prefix(prefix, "lm_head"))
         if self.config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(self.config.vocab_size)
@@ -45,7 +45,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
                          SupportsTranscription)
 from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
-                    make_layers)
+                    make_layers, maybe_prefix)
 
 logger = init_logger(__name__)
 
@@ -885,7 +885,8 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
         self.unpadded_vocab_size = config.vocab_size
         self.proj_out = ParallelLMHead(config.vocab_size,
                                        config.d_model,
-                                       quant_config=quant_config)
+                                       quant_config=quant_config,
+                                       prefix=maybe_prefix(prefix, "proj_out"))
         self.proj_out = self.proj_out.tie_weights(
             self.model.decoder.embed_tokens)
         logit_scale = getattr(config, "logit_scale", 1.0)