[Bugfix] gemma[2,3] interleaved attention when sliding window is disabled (#17180)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
commit 8de2901fea
parent c53e0730cb
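The fix swaps direct attribute access on the HF config for getattr with a None default, so Gemma 2/3 checkpoints whose config.json omits interleaved_sliding_window (i.e. sliding window disabled) no longer crash at model construction. A minimal sketch of the failure mode, using a hypothetical bare config object rather than a real transformers config:

from types import SimpleNamespace

# Hypothetical HF-style config whose config.json omits the optional
# field, as when sliding window is disabled for the checkpoint.
config = SimpleNamespace(sliding_window_pattern=6)

# Old pattern: direct attribute access raises AttributeError.
try:
    _ = config.interleaved_sliding_window is not None
except AttributeError as exc:
    print(f"old code crashes: {exc}")

# Fixed pattern: a missing field behaves like an explicit None.
use_sliding_window = getattr(config, "interleaved_sliding_window",
                             None) is not None
print(use_sliding_window)  # False -> fall back to global attention everywhere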
@@ -145,8 +145,8 @@ class Gemma2Attention(nn.Module):
         # reference:
         # https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa
         layer_idx = extract_layer_index(prefix)
-        use_sliding_window = (layer_idx % 2 == 0 and
-                              config.interleaved_sliding_window is not None)
+        use_sliding_window = (layer_idx % 2 == 0 and getattr(
+            config, "interleaved_sliding_window", None) is not None)
         sliding_window = config.interleaved_sliding_window if \
             use_sliding_window else None
         self.attn = Attention(self.num_heads,
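With the guard in place, Gemma 2's alternating local/global pattern degrades gracefully: an even-indexed layer only gets a sliding window when the config actually defines one, otherwise every layer runs global attention. A small sketch of the per-layer decision under an assumed window of 4096 (not the vLLM code itself):

from typing import Optional

def gemma2_sliding_window(layer_idx: int, config) -> Optional[int]:
    # Mirrors the patched logic: even layers are local (sliding window)
    # only when the config actually defines a window.
    window = getattr(config, "interleaved_sliding_window", None)
    use_sliding_window = layer_idx % 2 == 0 and window is not None
    return window if use_sliding_window else None

class WithWindow:       # hypothetical config, window enabled
    interleaved_sliding_window = 4096

class WithoutWindow:    # hypothetical config, window disabled/absent
    pass

print([gemma2_sliding_window(i, WithWindow()) for i in range(4)])
# [4096, None, 4096, None]
print([gemma2_sliding_window(i, WithoutWindow()) for i in range(4)])
# [None, None, None, None] -> all layers global, no AttributeError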
@@ -146,7 +146,9 @@ class Gemma3Attention(nn.Module):
 
         # TODO(woosuk): Add reference to the original HF implementation.
         layer_idx = extract_layer_index(prefix)
-        self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern)
+        self.is_sliding = (getattr(
+            config, "interleaved_sliding_window", None) is not None and bool(
+                (layer_idx + 1) % config.sliding_window_pattern))
         # Initialize the rotary embedding.
         if self.is_sliding:
             # Local attention. Override the values in config.json.
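Gemma 3 interleaves differently: one global layer per sliding_window_pattern layers, with (layer_idx + 1) % pattern hitting zero exactly on the global ones. The added getattr term forces is_sliding to False for every layer when the window field is missing. A sketch under an assumed pattern of 6 and window of 1024:

def is_sliding(layer_idx: int, config) -> bool:
    # Mirrors the patched logic: global attention on every
    # `sliding_window_pattern`-th layer, and on all layers when the
    # window field is missing from the config.
    return (getattr(config, "interleaved_sliding_window", None) is not None
            and bool((layer_idx + 1) % config.sliding_window_pattern))

class Cfg:  # hypothetical Gemma3-style text config
    interleaved_sliding_window = 1024
    sliding_window_pattern = 6

print([is_sliding(i, Cfg()) for i in range(6)])
# [True, True, True, True, True, False] -> layer 5 is the global layer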
@@ -478,7 +478,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         self.config = config
         self.quant_config = quant_config
         self.multimodal_config = multimodal_config
-        self.sliding_window = config.text_config.interleaved_sliding_window
+        self.sliding_window = getattr(config.text_config,
+                                      "interleaved_sliding_window", None)
 
         self.vision_tower = SiglipVisionModel(config.vision_config,
                                               quant_config,
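Here the multimodal wrapper caches the window size at construction time; with getattr it stores None instead of raising when the text config lacks the field, and the mask-building code in the next hunk branches on that None. A tiny sketch of the same guard on a hypothetical nested config:

class TextCfg:  # hypothetical text config without the optional field
    pass

class TopCfg:   # hypothetical top-level multimodal config
    text_config = TextCfg()

sliding_window = getattr(TopCfg().text_config,
                         "interleaved_sliding_window", None)
assert sliding_window is None  # later code branches on this instead of crashing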
@@ -680,13 +681,14 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
             global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
             global_attn_masks.append(global_attn_mask)
 
-            # Create a local causal mask with sliding window (1024).
-            local_attn_mask = torch.ones_like(global_attn_mask)
-            local_attn_mask = torch.tril(local_attn_mask,
-                                         diagonal=-self.sliding_window)
-            local_attn_mask = torch.where(local_attn_mask == 0,
-                                          global_attn_mask, float("-inf"))
-            local_attn_masks.append(local_attn_mask)
+            if self.sliding_window is not None:
+                # Create a local causal mask with sliding window (1024).
+                local_attn_mask = torch.ones_like(global_attn_mask)
+                local_attn_mask = torch.tril(local_attn_mask,
+                                             diagonal=-self.sliding_window)
+                local_attn_mask = torch.where(local_attn_mask == 0,
+                                              global_attn_mask, float("-inf"))
+                local_attn_masks.append(local_attn_mask)
         kwargs["global_attn_masks"] = global_attn_masks
         kwargs["local_attn_masks"] = local_attn_masks
         return kwargs
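The newly guarded block only runs when a window is configured. For a window w, torch.tril(..., diagonal=-w) flags keys w or more positions behind the query, and torch.where pushes those to -inf while inheriting the causal mask everywhere else. A runnable toy version with seq_len 5 and an assumed window of 2:

import torch

seq_len, window = 5, 2
# Causal additive mask: 0 where attention is allowed, -inf above the diagonal.
global_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")),
                         diagonal=1)
# 1 where the key is `window` or more positions behind the query.
too_old = torch.tril(torch.ones(seq_len, seq_len), diagonal=-window)
local_mask = torch.where(too_old == 0, global_mask, float("-inf"))
print(local_mask)
# Row i allows columns i-window+1 .. i (clamped at 0): the current
# token plus window-1 predecessors; everything else is -inf.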