[Model] Gemma3: Support untied word embeddings (#30827)

Signed-off-by: www-spam <panmahm@naver.com>
This commit is contained in:
KimHyemin 2025-12-18 00:11:18 +09:00 committed by GitHub
parent b7b6a60aca
commit 196cdc3224
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -39,7 +39,10 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
@ -532,12 +535,20 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
super().__init__()
self.config = config
# currently all existing Gemma models have `tie_word_embeddings` enabled
assert config.tie_word_embeddings
self.quant_config = quant_config
self.model = Gemma3Model(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
self.logits_processor = LogitsProcessor(
config.vocab_size, soft_cap=config.final_logit_softcapping
)
@ -565,7 +576,7 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
logits = self.logits_processor(self.model.embed_tokens, hidden_states)
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: