mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-01 23:51:18 +08:00
[Model] Gemma3: Support untied word embeddings (#30827)
Signed-off-by: www-spam <panmahm@naver.com>
This commit is contained in:
parent
b7b6a60aca
commit
196cdc3224
@@ -39,7 +39,10 @@ from vllm.model_executor.layers.linear import (
|
|||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
|
ParallelLMHead,
|
||||||
|
VocabParallelEmbedding,
|
||||||
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import (
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
default_weight_loader,
|
default_weight_loader,
|
||||||
maybe_remap_kv_scale_name,
|
maybe_remap_kv_scale_name,
|
||||||
@@ -532,12 +535,20 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
|||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
# currently all existing Gemma models have `tie_word_embeddings` enabled
|
|
||||||
assert config.tie_word_embeddings
|
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = Gemma3Model(
|
self.model = Gemma3Model(
|
||||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.lm_head = ParallelLMHead(
|
||||||
|
config.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=maybe_prefix(prefix, "lm_head"),
|
||||||
|
)
|
||||||
|
if config.tie_word_embeddings:
|
||||||
|
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
|
||||||
|
|
||||||
self.logits_processor = LogitsProcessor(
|
self.logits_processor = LogitsProcessor(
|
||||||
config.vocab_size, soft_cap=config.final_logit_softcapping
|
config.vocab_size, soft_cap=config.final_logit_softcapping
|
||||||
)
|
)
|
||||||
@@ -565,7 +576,7 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
|||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
) -> torch.Tensor | None:
|
) -> torch.Tensor | None:
|
||||||
logits = self.logits_processor(self.model.embed_tokens, hidden_states)
|
logits = self.logits_processor(self.lm_head, hidden_states)
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user