From 8ac3a4148796648d206a46144aa0dacea8977d55 Mon Sep 17 00:00:00 2001
From: Huamin Li <3ericli@gmail.com>
Date: Thu, 20 Nov 2025 23:53:30 -0800
Subject: [PATCH] [CI Failure] Fix Gemma3 RoPE configuration for sliding
 attention layers (#29111)

Signed-off-by: Huamin Li <3ericli@gmail.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Cyrus Leung
---
 vllm/model_executor/models/gemma3.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py
index 565719ae7faeb..4ad6fc89dcaf2 100644
--- a/vllm/model_executor/models/gemma3.py
+++ b/vllm/model_executor/models/gemma3.py
@@ -166,10 +166,12 @@ class Gemma3Attention(nn.Module):
         else:
             # Transformers v4 rope config.
             # Global attention. Use the values in config.json.
-            rope_parameters = config.rope_parameters.copy()
+            rope_parameters = config.rope_parameters
             # Local attention. Override the values in config.json.
             if self.is_sliding:
-                rope_parameters["rope_theta"] = config.rope_local_base_freq
+                rope_parameters = dict(
+                    rope_type="default", rope_theta=config.rope_local_base_freq
+                )
 
         self.rotary_emb = get_rope(
             self.head_dim,
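
The sketch below is a standalone illustration of the parameter-selection logic this patch implements; it is not vLLM's actual code. `FakeGemma3Config`, its field names, and the example values are assumptions for demonstration: global layers reuse the RoPE parameters from config.json unchanged, while sliding-window layers build a fresh dict with the default rope type and the local base frequency, so the global settings are neither mutated nor inherited.

from dataclasses import dataclass, field


@dataclass
class FakeGemma3Config:
    # Global-attention RoPE settings, as they might appear in config.json.
    rope_parameters: dict = field(
        default_factory=lambda: {"rope_type": "linear", "rope_theta": 1_000_000.0}
    )
    # Base frequency used by the local (sliding-window) attention layers.
    rope_local_base_freq: float = 10_000.0


def select_rope_parameters(config: FakeGemma3Config, is_sliding: bool) -> dict:
    """Pick the RoPE parameters for one attention layer.

    Mirrors the patched branch: sliding layers get a new dict with the
    default rope type and the local base frequency; global layers use
    config.rope_parameters as-is.
    """
    if is_sliding:
        # Local attention: override the values from config.json.
        return dict(rope_type="default", rope_theta=config.rope_local_base_freq)
    # Global attention: use the values from config.json.
    return config.rope_parameters


if __name__ == "__main__":
    cfg = FakeGemma3Config()
    print(select_rope_parameters(cfg, is_sliding=False))
    # {'rope_type': 'linear', 'rope_theta': 1000000.0}
    print(select_rope_parameters(cfg, is_sliding=True))
    # {'rope_type': 'default', 'rope_theta': 10000.0}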