[Model] Add module name prefixes to gemma3 (#15889)

Signed-off-by: Bartholomew Sabat <bartek@recursal.ai>
Co-authored-by: Bartholomew Sabat <bartek@recursal.ai>

parent 38327cf454
commit 9ec8257914
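
Threading a `prefix` string through these constructors gives each parallel linear layer its fully qualified module name (for example `model.layers.0.mlp.gate_up_proj`), which prefix-aware logic such as per-layer quantization settings can key on. The sketch below is illustrative only: the classes are minimal stand-ins for the naming pattern, not vLLM's real layer implementations.

    # Illustrative stand-ins (not vLLM classes): shows how a dotted prefix
    # composes into fully qualified module names as it is threaded through
    # nested constructors, mirroring the pattern in this commit.
    from typing import List


    class FakeLinear:

        def __init__(self, prefix: str = "") -> None:
            self.prefix = prefix


    class FakeMLP:

        def __init__(self, prefix: str = "") -> None:
            # Each child is named "<parent prefix>.<attribute name>".
            self.gate_up_proj = FakeLinear(prefix=f"{prefix}.gate_up_proj")
            self.down_proj = FakeLinear(prefix=f"{prefix}.down_proj")


    class FakeModel:

        def __init__(self, num_layers: int, prefix: str = "model") -> None:
            self.mlps: List[FakeMLP] = [
                FakeMLP(prefix=f"{prefix}.layers.{i}.mlp")
                for i in range(num_layers)
            ]


    model = FakeModel(num_layers=2)
    print(model.mlps[0].gate_up_proj.prefix)  # model.layers.0.mlp.gate_up_proj
    print(model.mlps[1].down_proj.prefix)     # model.layers.1.mlp.down_proj

The `prefix=f"{prefix}.<name>"` arguments added in the hunks below follow this convention, so the module names line up with the checkpoint's weight paths.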
vllm/model_executor/models/gemma3.py

@@ -59,16 +59,23 @@ class Gemma3MLP(nn.Module):
         intermediate_size: int,
         hidden_activation: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
+            hidden_size,
+            [intermediate_size] * 2,
             bias=False,
-            quant_config=quant_config)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config)
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+        )
         if hidden_activation != "gelu_pytorch_tanh":
             raise ValueError(
                 "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
@@ -125,12 +132,14 @@ class Gemma3Attention(nn.Module):
             self.total_num_kv_heads,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
 
         self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
@@ -293,6 +302,7 @@ class Gemma3DecoderLayer(nn.Module):
             intermediate_size=config.intermediate_size,
             hidden_activation=config.hidden_activation,
             quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )
         self.input_layernorm = GemmaRMSNorm(config.hidden_size,
                                             eps=config.rms_norm_eps)
@@ -344,6 +354,7 @@ class Gemma3Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            prefix=f"{prefix}.embed_tokens",
         )
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,