From 201dc98acc5f9be105281e674b46c95ad68d9fe9 Mon Sep 17 00:00:00 2001 From: Seungduk Kim Date: Thu, 6 Nov 2025 16:07:36 +0900 Subject: [PATCH] Fix hard-coded parameter name in gemma3n.py (#27946) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Seungduk Kim Signed-off-by: Biswa Panda Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Biswa Panda Co-authored-by: Nicolò Lucchesi --- vllm/model_executor/models/gemma3n.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py index f7a732e3a601c..547884f393eb0 100644 --- a/vllm/model_executor/models/gemma3n.py +++ b/vllm/model_executor/models/gemma3n.py @@ -357,8 +357,27 @@ class Gemma3nAttention(nn.Module): offset = 2 if self.sliding_window is not None else 1 kv_shared_layer_index = first_kv_shared_layer_idx - offset if kv_shared_layer_index >= 0: + # Different model wrappers expose layer parameters under + # different parent attributes. + # For example: + # - Gemma3nForCausalLM → parameters live under "model.layers" + # - Gemma3nForConditionalGeneration → + # under "language_model.model.layers" + # This logic extracts the portion of the parameter name + # *before* ".layers." + # so downstream code can consistently reference the correct + # model root regardless of which wrapper class was used. + if ".layers." in prefix: + param_name_before_layers = prefix.split(".layers.")[0] + else: + raise ValueError( + "Unexpected prefix format for Gemma3nAttention: " + f"'{prefix}'. The prefix is expected to contain " + "'.layers.' to correctly determine the KV sharing " + "target layer." + ) # Only the greater layer is required to specify sharing. - kv_sharing_target_layer_name = f"language_model.model.layers.{kv_shared_layer_index}.self_attn.attn" # noqa: E501 + kv_sharing_target_layer_name = f"{param_name_before_layers}.layers.{kv_shared_layer_index}.self_attn.attn" # noqa: E501 self.rotary_emb = get_rope( self.head_dim,