mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 02:55:40 +08:00
[Model] Pass param prefix to LLMHead (#24862)
Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
parent
03191cd8f0
commit
4a9375fe9d
@ -427,6 +427,7 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
|
||||
self.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
|
||||
@ -539,6 +539,7 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
config.text_config.hidden_size,
|
||||
org_num_embeddings=self.language_model.org_vocab_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
|
||||
@ -51,7 +51,8 @@ from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
|
||||
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
|
||||
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
||||
@ -394,7 +395,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
|
||||
position_embedding=position_embedding)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"))
|
||||
self.lm_head.weight.weight_loader = self.lm_head_weight_loader
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
|
||||
@ -514,6 +514,7 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
# We need bigger padding if using lora for kernel
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
# Used to track and store by the Mamba cache between steps.
|
||||
self.mamba_cache: Optional[MambaCacheManager] = None
|
||||
|
||||
@ -330,7 +330,9 @@ class BloomForCausalLM(nn.Module, SupportsPP, SupportsQuant):
|
||||
self.lm_head = self.transformer.word_embeddings
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(self.config.vocab_size,
|
||||
self.config.hidden_size)
|
||||
self.config.hidden_size,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "lm_head"))
|
||||
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
|
||||
@ -960,6 +960,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
|
||||
@ -438,6 +438,7 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size)
|
||||
|
||||
@ -453,9 +453,12 @@ class DeepseekForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.quant_config = quant_config
|
||||
self.model = DeepseekModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
@ -199,7 +199,8 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
|
||||
|
||||
self.lm_head = ParallelLMHead(self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"))
|
||||
|
||||
logit_scale = getattr(self.config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(self.config.vocab_size,
|
||||
|
||||
@ -823,9 +823,12 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts,
|
||||
self.model = DeepseekV2Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
if get_pp_group().is_last_rank:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
@ -504,7 +504,9 @@ class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
|
||||
if get_pp_group().is_last_rank:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "lm_head"))
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
@ -562,7 +562,9 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
|
||||
if get_pp_group().is_last_rank:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "lm_head"))
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
|
||||
@ -557,7 +557,9 @@ class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP):
|
||||
if get_pp_group().is_last_rank:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "lm_head"))
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
|
||||
@ -158,7 +158,8 @@ class ErnieMTP(nn.Module, SupportsPP):
|
||||
prefix=maybe_prefix(
|
||||
prefix, "model"))
|
||||
self.lm_head = ParallelLMHead(self.config.vocab_size,
|
||||
self.config.hidden_size)
|
||||
self.config.hidden_size,
|
||||
prefix=maybe_prefix(prefix, "lm_head"))
|
||||
self.sampler = get_sampler()
|
||||
|
||||
if self.config.tie_word_embeddings:
|
||||
|
||||
@ -502,6 +502,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.transformer.wte.weight
|
||||
|
||||
@ -485,6 +485,7 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
|
||||
@ -473,6 +473,7 @@ class FalconForCausalLM(nn.Module, SupportsPP):
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
|
||||
@ -607,6 +607,7 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
# compatibility
|
||||
if not lora_config else
|
||||
lora_config.lora_vocab_padding_size),
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
self.lm_head_multiplier = config.lm_head_multiplier
|
||||
if self.tie_word_embeddings:
|
||||
|
||||
@ -608,7 +608,9 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
|
||||
if get_pp_group().is_last_rank:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "lm_head"))
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
@ -302,7 +302,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.transformer.vocab_size,
|
||||
self.transformer.embed_dim,
|
||||
org_num_embeddings=self.config.vocab_size)
|
||||
org_num_embeddings=self.config.vocab_size,
|
||||
prefix=maybe_prefix(prefix, "lm_head"))
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
|
||||
@ -306,6 +306,7 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
|
||||
config.n_embd,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
|
||||
@ -655,6 +655,7 @@ class GptOssForCausalLM(nn.Module, SupportsPP):
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(self.config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
|
||||
@ -434,6 +434,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
|
||||
@ -487,6 +487,7 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
|
||||
@ -58,7 +58,7 @@ from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
|
||||
make_layers)
|
||||
make_layers, maybe_prefix)
|
||||
|
||||
|
||||
def _is_moe(config: PretrainedConfig) -> bool:
|
||||
@ -871,6 +871,7 @@ class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP):
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
|
||||
@ -606,6 +606,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
config.text_config.vocab_size,
|
||||
config.text_config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
if self.config.text_config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.text_model.wte.weight
|
||||
|
||||
@ -302,7 +302,9 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
|
||||
self.lm_head = self.transformer.wte
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(self.config.vocab_size,
|
||||
self.config.hidden_size)
|
||||
self.config.hidden_size,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "lm_head"))
|
||||
if hasattr(config, "width_scale"):
|
||||
self.output_logits_scale = config.width_scale
|
||||
else:
|
||||
|
||||
@ -502,6 +502,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
# We need bigger padding if using lora for kernel
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
# Used to track and store by the Mamba cache between steps.
|
||||
self.mamba_cache: Optional[MambaCacheManager] = None
|
||||
|
||||
@ -328,6 +328,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
config.text_config.hidden_size,
|
||||
org_num_embeddings=self.config.text_config.vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
@ -220,7 +220,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
|
||||
self.config.hidden_size,
|
||||
org_num_embeddings=self.config.draft_vocab_size,
|
||||
padding_size=(DEFAULT_VOCAB_PADDING_SIZE),
|
||||
prefix="")
|
||||
prefix=maybe_prefix(prefix, "lm_head"))
|
||||
self.logits_processor = LogitsProcessor(self.config.draft_vocab_size,
|
||||
scale=logit_scale)
|
||||
self.draft_id_to_target_id = nn.Parameter(
|
||||
|
||||
@ -223,6 +223,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
|
||||
# We need bigger padding if using lora for kernel
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
|
||||
# Used to track and store by the Mamba cache between steps.
|
||||
|
||||
@ -278,6 +278,7 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
||||
# We need bigger padding if using lora for kernel
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.lm_head.tie_weights(self.backbone.embeddings)
|
||||
|
||||
@ -15,6 +15,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
|
||||
from .utils import maybe_prefix
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
|
||||
@ -71,6 +73,7 @@ class Medusa(nn.Module):
|
||||
config.hidden_size,
|
||||
org_num_embeddings=self.truncated_vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
self.lm_heads = [
|
||||
self.lm_head for _ in range(self.config.num_heads)
|
||||
|
||||
@ -158,7 +158,8 @@ class MiMoMTP(nn.Module):
|
||||
prefix=maybe_prefix(
|
||||
prefix, "model"))
|
||||
self.lm_head = ParallelLMHead(self.config.vocab_size,
|
||||
self.config.hidden_size)
|
||||
self.config.hidden_size,
|
||||
prefix=maybe_prefix(prefix, "lm_head"))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -547,6 +547,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
|
||||
|
||||
@ -338,6 +338,7 @@ class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
|
||||
|
||||
@ -702,6 +702,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
self.config.hidden_size,
|
||||
org_num_embeddings=self.config.vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
|
||||
@ -507,6 +507,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
|
||||
@ -1403,6 +1403,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
config.embedding_size or config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
|
||||
self.logits_processor = LogitsProcessor(config.embedding_size
|
||||
|
||||
@ -466,6 +466,7 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
|
||||
@ -565,6 +565,7 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
# We need bigger padding if using lora for kernel
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
# Used to track and store by the Mamba cache between steps.
|
||||
self.mamba_cache: Optional[MambaCacheManager] = None
|
||||
|
||||
@ -364,6 +364,7 @@ class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
|
||||
@ -450,7 +450,8 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"))
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
|
||||
@ -375,7 +375,9 @@ class OPTForCausalLM(nn.Module, SupportsPP):
|
||||
self.lm_head = self.model.decoder.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.word_embed_proj_dim)
|
||||
config.word_embed_proj_dim,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "lm_head"))
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
@ -314,7 +314,8 @@ class OrionForCausalLM(nn.Module, SupportsPP):
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"))
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
@ -307,7 +307,8 @@ class PersimmonForCausalLM(nn.Module, SupportsPP):
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
bias=False)
|
||||
bias=False,
|
||||
prefix=maybe_prefix(prefix, "lm_head"))
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
@ -322,7 +322,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"))
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
@ -630,6 +630,7 @@ class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size),
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
self.embedding_bias = None
|
||||
# Used to track and store by the Mamba cache between steps.
|
||||
|
||||
@ -989,6 +989,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
|
||||
|
||||
@ -645,6 +645,7 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
if not lora_config else lora_config.lora_vocab_padding_size),
|
||||
quant_config=None,
|
||||
bias=True,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size)
|
||||
|
||||
@ -271,7 +271,8 @@ class QWenBaseModel(nn.Module):
|
||||
prefix, "transformer"))
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"))
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.transformer.wte.weight
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
@ -519,7 +519,8 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"))
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
@ -605,7 +605,8 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"))
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
@ -1089,7 +1089,7 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
# We need bigger padding if using lora for kernel
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
)
|
||||
prefix=maybe_prefix(prefix, "lm_head"))
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
|
||||
@ -238,7 +238,8 @@ class Qwen3NextMTP(nn.Module, SupportsPP):
|
||||
self.lm_head = ParallelLMHead(self.unpadded_vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE)
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||
prefix=maybe_prefix(prefix, "lm_head"))
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
|
||||
@ -469,6 +469,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
|
||||
@ -35,7 +35,8 @@ from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -386,6 +387,7 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size)
|
||||
|
||||
@ -941,6 +941,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
# We need bigger padding if using lora for kernel
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
# Tie weights with input embeddings if using same dimensions
|
||||
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user