diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py
index 40f6d100c767e..70f72b5cb9beb 100644
--- a/vllm/model_executor/models/gemma3.py
+++ b/vllm/model_executor/models/gemma3.py
@@ -39,7 +39,10 @@ from vllm.model_executor.layers.linear import (
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
@@ -532,12 +535,20 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 
         super().__init__()
         self.config = config
-        # currently all existing Gemma models have `tie_word_embeddings` enabled
-        assert config.tie_word_embeddings
         self.quant_config = quant_config
         self.model = Gemma3Model(
             vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
         )
+
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "lm_head"),
+        )
+        if config.tie_word_embeddings:
+            self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
+
         self.logits_processor = LogitsProcessor(
             config.vocab_size, soft_cap=config.final_logit_softcapping
         )
@@ -565,7 +576,7 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         self,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor | None:
-        logits = self.logits_processor(self.model.embed_tokens, hidden_states)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
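
For reference, below is a minimal sketch of the head-wiring pattern the diff introduces, using plain torch.nn.Embedding and nn.Linear as stand-ins for vLLM's VocabParallelEmbedding and ParallelLMHead; the TinyCausalLM class, its sizes, and its compute_logits helper are illustrative assumptions, not vLLM code. The point it demonstrates: the lm_head is always constructed (so checkpoints that ship a separate lm_head.weight can be loaded), and it shares weights with the input embedding only when tie_word_embeddings is set.

import torch
from torch import nn


class TinyCausalLM(nn.Module):
    """Illustrative stand-in for Gemma3ForCausalLM's output-head wiring."""

    def __init__(self, vocab_size: int, hidden_size: int, tie_word_embeddings: bool):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        # Always build an output projection so untied checkpoints can load
        # their own lm_head weights ...
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        # ... and share storage with the input embedding only when tied,
        # mirroring `self.lm_head.tie_weights(self.model.embed_tokens)`.
        if tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Project hidden states to vocabulary logits through lm_head, as the
        # diff's `self.logits_processor(self.lm_head, hidden_states)` now does.
        return self.lm_head(hidden_states)


tied = TinyCausalLM(vocab_size=8, hidden_size=4, tie_word_embeddings=True)
untied = TinyCausalLM(vocab_size=8, hidden_size=4, tie_word_embeddings=False)
assert tied.lm_head.weight is tied.embed_tokens.weight
assert untied.lm_head.weight is not untied.embed_tokens.weight
print(tied.compute_logits(torch.zeros(1, 4)).shape)  # torch.Size([1, 8])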