[Bugfix] gemma[2,3] interleaved attention when sliding window is disabled (#17180)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang 2025-04-26 10:53:51 +08:00 committed by GitHub
parent c53e0730cb
commit 8de2901fea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 15 additions and 11 deletions

View File

@ -145,8 +145,8 @@ class Gemma2Attention(nn.Module):
# reference: # reference:
# https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa # https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa
layer_idx = extract_layer_index(prefix) layer_idx = extract_layer_index(prefix)
use_sliding_window = (layer_idx % 2 == 0 and use_sliding_window = (layer_idx % 2 == 0 and getattr(
config.interleaved_sliding_window is not None) config, "interleaved_sliding_window", None) is not None)
sliding_window = config.interleaved_sliding_window if \ sliding_window = config.interleaved_sliding_window if \
use_sliding_window else None use_sliding_window else None
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,

View File

@ -146,7 +146,9 @@ class Gemma3Attention(nn.Module):
# TODO(woosuk): Add reference to the original HF implementation. # TODO(woosuk): Add reference to the original HF implementation.
layer_idx = extract_layer_index(prefix) layer_idx = extract_layer_index(prefix)
self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern) self.is_sliding = (getattr(
config, "interleaved_sliding_window", None) is not None and bool(
(layer_idx + 1) % config.sliding_window_pattern))
# Initialize the rotary embedding. # Initialize the rotary embedding.
if self.is_sliding: if self.is_sliding:
# Local attention. Override the values in config.json. # Local attention. Override the values in config.json.

View File

@ -478,7 +478,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.sliding_window = config.text_config.interleaved_sliding_window self.sliding_window = getattr(config.text_config,
"interleaved_sliding_window", None)
self.vision_tower = SiglipVisionModel(config.vision_config, self.vision_tower = SiglipVisionModel(config.vision_config,
quant_config, quant_config,
@ -680,13 +681,14 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask) global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
global_attn_masks.append(global_attn_mask) global_attn_masks.append(global_attn_mask)
# Create a local causal mask with sliding window (1024). if self.sliding_window is not None:
local_attn_mask = torch.ones_like(global_attn_mask) # Create a local causal mask with sliding window (1024).
local_attn_mask = torch.tril(local_attn_mask, local_attn_mask = torch.ones_like(global_attn_mask)
diagonal=-self.sliding_window) local_attn_mask = torch.tril(local_attn_mask,
local_attn_mask = torch.where(local_attn_mask == 0, diagonal=-self.sliding_window)
global_attn_mask, float("-inf")) local_attn_mask = torch.where(local_attn_mask == 0,
local_attn_masks.append(local_attn_mask) global_attn_mask, float("-inf"))
local_attn_masks.append(local_attn_mask)
kwargs["global_attn_masks"] = global_attn_masks kwargs["global_attn_masks"] = global_attn_masks
kwargs["local_attn_masks"] = local_attn_masks kwargs["local_attn_masks"] = local_attn_masks
return kwargs return kwargs