diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py
index 55c96f649fbeb..fb8eccc55078a 100644
--- a/vllm/model_executor/models/gemma3.py
+++ b/vllm/model_executor/models/gemma3.py
@@ -59,16 +59,23 @@ class Gemma3MLP(nn.Module):
         intermediate_size: int,
         hidden_activation: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
+            hidden_size,
+            [intermediate_size] * 2,
             bias=False,
-            quant_config=quant_config)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config)
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+        )
         if hidden_activation != "gelu_pytorch_tanh":
             raise ValueError(
                 "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
@@ -125,12 +132,14 @@ class Gemma3Attention(nn.Module):
             self.total_num_kv_heads,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
         self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
@@ -293,6 +302,7 @@ class Gemma3DecoderLayer(nn.Module):
             intermediate_size=config.intermediate_size,
             hidden_activation=config.hidden_activation,
             quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )
         self.input_layernorm = GemmaRMSNorm(config.hidden_size,
                                             eps=config.rms_norm_eps)
@@ -344,6 +354,7 @@ class Gemma3Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            prefix=f"{prefix}.embed_tokens",
        )
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
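
For context, a minimal sketch of the naming pattern this change enables: each submodule is tagged with its parent's fully qualified name plus its own attribute name, so a quantization config can later match layers such as model.layers.0.mlp.gate_up_proj by name. This sketch is hypothetical and not vLLM code; ToyMLP and its *_name attributes are illustrative stand-ins for the parallel linear layers in the diff.

# Minimal sketch (hypothetical, not vLLM code): how the added `prefix`
# argument composes into fully qualified module names.
import torch.nn as nn


class ToyMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int,
                 prefix: str = "") -> None:
        super().__init__()
        # Mirrors the diff: each sub-layer is tagged "<parent prefix>.<name>".
        self.gate_up_proj_name = f"{prefix}.gate_up_proj"
        self.down_proj_name = f"{prefix}.down_proj"
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size,
                                      bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)


mlp = ToyMLP(hidden_size=8, intermediate_size=16, prefix="model.layers.0.mlp")
print(mlp.gate_up_proj_name)  # model.layers.0.mlp.gate_up_proj
print(mlp.down_proj_name)     # model.layers.0.mlp.down_proj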