[Bugfix] Revert custom attention mask for gemma3-mm (#28995)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Isotr0py 2025-11-20 13:23:22 +08:00 committed by GitHub
parent fe25772aa9
commit 64192d5624
4 changed files with 1 addition and 172 deletions

View File

@@ -32,7 +32,6 @@ from vllm.transformers_utils.config import (
     try_get_generation_config,
     try_get_safetensors_metadata,
     try_get_tokenizer_config,
-    uses_custom_attention_masks,
     uses_mrope,
 )
 from vllm.transformers_utils.gguf_utils import (
@@ -1625,10 +1624,6 @@ class ModelConfig:
     def uses_mrope(self) -> bool:
         return uses_mrope(self.hf_config)
-    @property
-    def uses_custom_attention_masks(self) -> bool:
-        return uses_custom_attention_masks(self.hf_config)
     @property
     def is_multimodal_model(self) -> bool:
         return self.multimodal_config is not None

View File

@@ -596,7 +596,7 @@ class Gemma3ForConditionalGeneration(
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
-    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []
@@ -644,142 +644,6 @@ class Gemma3ForConditionalGeneration(
         return hidden_states
-    def generate_attention_masks(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        mask_dtype: torch.dtype,
-    ) -> dict[str, Any]:
-        """Generate custom attention masks for Gemma3 multimodal inputs.
-        This is called by V1 engine's gpu_model_runner during preprocessing
-        to generate attention masks that allow bidirectional attention between
-        image tokens while maintaining causal attention for text.
-        """
-        # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
-        # This is a HACK. Fix this.
-        start_indices = (positions == 0).cpu().nonzero()
-        num_seqs = len(start_indices)
-        seq_lens = []
-        for i in range(num_seqs):
-            start_idx = start_indices[i]
-            end_idx = start_indices[i + 1] if i < num_seqs - 1 else len(input_ids)
-            seq_lens.append(end_idx - start_idx)
-        global_attn_masks = []
-        local_attn_masks = []
-        start_idx = 0
-        for seq_idx, seq_len in enumerate(seq_lens):
-            end_idx = start_idx + seq_len
-            input_token_ids = input_ids[start_idx:end_idx]
-            # Find image token positions
-            img_pos = input_token_ids == self.config.image_token_index
-            start_idx = end_idx
-            # Create a global causal mask
-            global_attn_mask = torch.empty(
-                1,
-                1,
-                seq_len,
-                seq_len,
-                dtype=mask_dtype,
-                device=input_ids.device,
-            )
-            global_attn_mask.fill_(float("-inf"))
-            # Fill the lower triangle with 0 (causal attention)
-            global_attn_mask = global_attn_mask.triu(diagonal=1)
-            # Enable bidirectional attention between image tokens
-            img_mask = torch.zeros_like(global_attn_mask)
-            img_mask[:, :, :, img_pos] += 1
-            img_mask[:, :, img_pos, :] += 1
-            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
-            global_attn_masks.append(global_attn_mask)
-            # GGUF compatibility: config might be Gemma3TextConfig directly
-            text_config = getattr(self.config, "text_config", self.config)
-            sliding_window = text_config.sliding_window
-            if sliding_window is not None:
-                # Create a local causal mask with sliding window (1024)
-                local_attn_mask = torch.ones_like(global_attn_mask)
-                local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
-                local_attn_mask = torch.where(
-                    local_attn_mask == 0, global_attn_mask, float("-inf")
-                )
-                local_attn_masks.append(local_attn_mask)
-        return {
-            "has_images": True,
-            "seq_lens": seq_lens,
-            "global_attn_masks": global_attn_masks,
-            "local_attn_masks": local_attn_masks,
-        }
-    def prepare_attn_masks(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        mask_dtype: torch.dtype,
-        **kwargs,
-    ):
-        kwargs["has_images"] = True
-        # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
-        # This is a HACK. Fix this.
-        start_indices = (positions == 0).cpu().nonzero()
-        num_seqs = len(start_indices)
-        seq_lens = []
-        for i in range(num_seqs):
-            start_idx = start_indices[i].item()
-            if i < num_seqs - 1:
-                end_idx = start_indices[i + 1].item()
-            else:
-                end_idx = len(input_ids)
-            seq_lens.append(end_idx - start_idx)
-        kwargs["seq_lens"] = seq_lens
-        global_attn_masks = []
-        local_attn_masks = []
-        start_idx = 0
-        for seq_len in seq_lens:
-            end_idx = start_idx + seq_len
-            input_token_ids = input_ids[start_idx:end_idx]
-            start_idx = end_idx
-            # Create a global causal mask.
-            global_attn_mask = torch.empty(
-                1,
-                1,
-                seq_len,
-                seq_len,
-                dtype=mask_dtype,
-                device=input_ids.device,
-            )
-            global_attn_mask.fill_(float("-inf"))
-            # Fill the lower triangle with 0.
-            global_attn_mask = global_attn_mask.triu(diagonal=1)
-            # Consider the bidirectional attention between image tokens.
-            img_mask = torch.zeros_like(global_attn_mask)
-            img_pos = input_token_ids == self.config.image_token_index
-            img_mask[:, :, :, img_pos] += 1
-            img_mask[:, :, img_pos, :] += 1
-            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
-            global_attn_masks.append(global_attn_mask)
-            sliding_window = self.config.text_config.sliding_window
-            if sliding_window is not None:
-                # Create a local causal mask with sliding window (1024).
-                local_attn_mask = torch.ones_like(global_attn_mask)
-                local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
-                local_attn_mask = torch.where(
-                    local_attn_mask == 0, global_attn_mask, float("-inf")
-                )
-                local_attn_masks.append(local_attn_mask)
-        kwargs["global_attn_masks"] = global_attn_masks
-        kwargs["local_attn_masks"] = local_attn_masks
-        return kwargs
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
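
For reference, the mask construction being reverted here boils down to a small torch pattern: build a causal mask per sequence, unmask every (query, key) pair where both positions hold an image token, and derive a sliding-window variant for the local-attention layers. The following standalone sketch reproduces that pattern with toy values (token id 9 stands in for config.image_token_index, and the window of 3 for text_config.sliding_window); it is an illustration of the reverted logic, not the vLLM API.

import torch

# Toy packed batch: a new sequence starts wherever position == 0.
input_ids = torch.tensor([1, 9, 9, 2, 3, 9, 9, 4])
positions = torch.tensor([0, 1, 2, 3, 0, 1, 2, 3])
image_token_index = 9   # stand-in for config.image_token_index
sliding_window = 3      # stand-in for text_config.sliding_window

# Recover per-sequence lengths from the position ids (the reverted code's HACK).
starts = (positions == 0).nonzero().flatten().tolist() + [positions.numel()]
seq_lens = [b - a for a, b in zip(starts, starts[1:])]

global_attn_masks, local_attn_masks = [], []
start_idx = 0
for seq_len in seq_lens:
    token_ids = input_ids[start_idx : start_idx + seq_len]
    start_idx += seq_len

    # Causal mask in additive form: 0 on/below the diagonal, -inf above it.
    global_mask = torch.full((1, 1, seq_len, seq_len), float("-inf")).triu(diagonal=1)

    # Unmask (query, key) pairs where both positions are image tokens.
    img_pos = token_ids == image_token_index
    img_mask = torch.zeros_like(global_mask)
    img_mask[:, :, :, img_pos] += 1
    img_mask[:, :, img_pos, :] += 1
    global_mask = torch.where(img_mask == 2, 0.0, global_mask)
    global_attn_masks.append(global_mask)

    # Local variant: keep the global mask inside the window, -inf beyond it.
    outside_window = torch.tril(torch.ones_like(global_mask), diagonal=-sliding_window)
    local_attn_masks.append(torch.where(outside_window == 0, global_mask, float("-inf")))

print(global_attn_masks[0][0, 0])
print(local_attn_masks[0][0, 0])

In the reverted methods the same loop runs once per scheduled sequence, with mask_dtype and device taken from the runner's input tensors.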

View File

@@ -520,17 +520,6 @@ def is_interleaved(config: PretrainedConfig) -> bool:
     return False
-def uses_custom_attention_masks(config: PretrainedConfig) -> bool:
-    """Detect if model uses custom attention mask generation for multimodal.
-    Some multimodal models require custom attention masks that enable
-    bidirectional attention between image tokens while maintaining causal
-    attention for text tokens. Currently applies to Gemma3 multimodal models.
-    """
-    architectures = getattr(config, "architectures", [])
-    return "Gemma3ForConditionalGeneration" in architectures
 def _maybe_update_auto_config_kwargs(kwargs: dict[str, Any], model_type: str):
     """
     Update kwargs for AutoConfig initialization based on model_type
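
The helper removed above is a plain architecture-name check on the HF config. A self-contained sketch of that behavior, using types.SimpleNamespace objects as stand-ins for transformers.PretrainedConfig (the example architecture strings and variable names are illustrative):

from types import SimpleNamespace

def uses_custom_attention_masks(config) -> bool:
    # Mirrors the removed helper: models opt in purely by architecture name.
    architectures = getattr(config, "architectures", []) or []
    return "Gemma3ForConditionalGeneration" in architectures

gemma3_mm = SimpleNamespace(architectures=["Gemma3ForConditionalGeneration"])
gemma3_text = SimpleNamespace(architectures=["Gemma3ForCausalLM"])

assert uses_custom_attention_masks(gemma3_mm) is True
assert uses_custom_attention_masks(gemma3_text) is False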

View File

@@ -324,7 +324,6 @@ class GPUModelRunner(
         # Multi-modal data support
         self.mm_registry = MULTIMODAL_REGISTRY
         self.uses_mrope = model_config.uses_mrope
-        self.uses_custom_attention_masks = model_config.uses_custom_attention_masks
         self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
             model_config
         )
@@ -2352,24 +2351,6 @@ class GPUModelRunner(
                 **self._init_model_kwargs(num_scheduled_tokens),
                 **self._extract_mm_kwargs(scheduler_output),
             }
-            # Generate custom attention masks for models that require them.
-            # V1 pre-generates embeddings, so forward() skips prepare_attn_masks().
-            # Check mm_features (mm_embeds is empty during decode).
-            has_mm_features = any(
-                req_state.mm_features for req_state in self.requests.values()
-            )
-            if (
-                self.uses_custom_attention_masks
-                and has_mm_features
-                and hasattr(self.model, "generate_attention_masks")
-            ):
-                mask_kwargs = self.model.generate_attention_masks(
-                    self.input_ids.gpu[:num_scheduled_tokens],
-                    self.positions.gpu[:num_scheduled_tokens],
-                    mask_dtype=self.model.dtype,
-                )
-                model_kwargs.update(mask_kwargs)
         elif self.enable_prompt_embeds and is_first_rank:
             # Get the input embeddings for the tokens that are not input embeds,
             # then put them into the appropriate positions.
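
For context, the runner-side hook removed above reduces to a gating pattern: generate masks only when the model has opted in, multimodal features are actually scheduled (mm_embeds is empty during decode, hence the mm_features check), and the model exposes generate_attention_masks; the result is then spliced into the forward kwargs. A hedged sketch with stand-in names (maybe_add_attention_masks, requests, and model_kwargs are illustrative, not the real runner attributes):

from typing import Any

def maybe_add_attention_masks(
    model: Any,
    requests: dict[str, Any],
    input_ids,            # flat token ids for the scheduled tokens
    positions,            # matching position ids
    model_kwargs: dict[str, Any],
    uses_custom_attention_masks: bool,
) -> dict[str, Any]:
    # mm_embeds is empty during decode, so gate on mm_features instead.
    has_mm_features = any(
        getattr(req_state, "mm_features", None) for req_state in requests.values()
    )
    if (
        uses_custom_attention_masks
        and has_mm_features
        and hasattr(model, "generate_attention_masks")
    ):
        mask_kwargs = model.generate_attention_masks(
            input_ids, positions, mask_dtype=model.dtype
        )
        model_kwargs.update(mask_kwargs)
    return model_kwargs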