Mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2026-01-03 15:08:42 +08:00
[Model] Gemma3: Support untied word embeddings (#30827)

Signed-off-by: www-spam <panmahm@naver.com>

parent b7b6a60aca
commit 196cdc3224
@@ -39,7 +39,10 @@ from vllm.model_executor.layers.linear import (
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
@@ -532,12 +535,20 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 
         super().__init__()
         self.config = config
-        # currently all existing Gemma models have `tie_word_embeddings` enabled
-        assert config.tie_word_embeddings
         self.quant_config = quant_config
         self.model = Gemma3Model(
             vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
         )
+
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "lm_head"),
+        )
+        if config.tie_word_embeddings:
+            self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
+
         self.logits_processor = LogitsProcessor(
             config.vocab_size, soft_cap=config.final_logit_softcapping
         )
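The constructor change above is the core of the feature: a ParallelLMHead is now always created, and its weight is shared with the input embedding only when the checkpoint's config requests tying. The sketch below illustrates the same pattern in plain PyTorch, independent of vLLM's parallel layers; the class and parameter names here are illustrative, not taken from the vLLM code base.

import torch
import torch.nn as nn


class ToyCausalLM(nn.Module):
    """Minimal model showing tied vs. untied word embeddings."""

    def __init__(self, vocab_size: int, hidden_size: int, tie_word_embeddings: bool):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        # Always build a separate output projection (the "LM head") ...
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        if tie_word_embeddings:
            # ... and share its weight with the input embedding only when
            # the config asks for tying, mirroring the conditional above.
            self.lm_head.weight = self.embed_tokens.weight

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Logits always go through lm_head, tied or not.
        return self.lm_head(hidden_states)


tied = ToyCausalLM(vocab_size=8, hidden_size=4, tie_word_embeddings=True)
untied = ToyCausalLM(vocab_size=8, hidden_size=4, tie_word_embeddings=False)
assert tied.lm_head.weight is tied.embed_tokens.weight
assert untied.lm_head.weight is not untied.embed_tokens.weight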
@@ -565,7 +576,7 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         self,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor | None:
-        logits = self.logits_processor(self.model.embed_tokens, hidden_states)
+        logits = self.logits_processor(self.lm_head, hidden_states)
         return logits
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
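With compute_logits now routed through self.lm_head instead of self.model.embed_tokens, tied and untied checkpoints share one code path. The short check below, again plain PyTorch with made-up tensors rather than vLLM objects, shows why that matters: projecting hidden states against the embedding table only matches the head projection when the two weights are actually shared.

import torch

torch.manual_seed(0)
vocab_size, hidden_size = 8, 4

embed_weight = torch.randn(vocab_size, hidden_size)    # input embedding table
lm_head_weight = torch.randn(vocab_size, hidden_size)  # separate, untied output head
hidden = torch.randn(2, hidden_size)                   # final hidden states

logits_via_head = hidden @ lm_head_weight.t()   # project through the LM head
logits_via_embed = hidden @ embed_weight.t()    # the tied-only shortcut

# For an untied checkpoint the two projections disagree, so the logits
# must come from the dedicated lm_head weight.
assert not torch.allclose(logits_via_head, logits_via_embed)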