Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 14:25:21 +08:00)
[Bugfix] Revert custom attention mask for gemma3-mm (#28995)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>

This commit is contained in:
parent fe25772aa9
commit 64192d5624
@@ -32,7 +32,6 @@ from vllm.transformers_utils.config import (
    try_get_generation_config,
    try_get_safetensors_metadata,
    try_get_tokenizer_config,
    uses_custom_attention_masks,
    uses_mrope,
)
from vllm.transformers_utils.gguf_utils import (
@@ -1625,10 +1624,6 @@ class ModelConfig:
    def uses_mrope(self) -> bool:
        return uses_mrope(self.hf_config)

    @property
    def uses_custom_attention_masks(self) -> bool:
        return uses_custom_attention_masks(self.hf_config)

    @property
    def is_multimodal_model(self) -> bool:
        return self.multimodal_config is not None

@@ -596,7 +596,7 @@ class Gemma3ForConditionalGeneration(
    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
@@ -644,142 +644,6 @@ class Gemma3ForConditionalGeneration(

        return hidden_states

    def generate_attention_masks(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        mask_dtype: torch.dtype,
    ) -> dict[str, Any]:
        """Generate custom attention masks for Gemma3 multimodal inputs.

        This is called by V1 engine's gpu_model_runner during preprocessing
        to generate attention masks that allow bidirectional attention between
        image tokens while maintaining causal attention for text.
        """
        # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
        # This is a HACK. Fix this.
        start_indices = (positions == 0).cpu().nonzero()
        num_seqs = len(start_indices)
        seq_lens = []
        for i in range(num_seqs):
            start_idx = start_indices[i]
            end_idx = start_indices[i + 1] if i < num_seqs - 1 else len(input_ids)
            seq_lens.append(end_idx - start_idx)

        global_attn_masks = []
        local_attn_masks = []
        start_idx = 0
        for seq_idx, seq_len in enumerate(seq_lens):
            end_idx = start_idx + seq_len
            input_token_ids = input_ids[start_idx:end_idx]

            # Find image token positions
            img_pos = input_token_ids == self.config.image_token_index

            start_idx = end_idx

            # Create a global causal mask
            global_attn_mask = torch.empty(
                1,
                1,
                seq_len,
                seq_len,
                dtype=mask_dtype,
                device=input_ids.device,
            )
            global_attn_mask.fill_(float("-inf"))
            # Fill the lower triangle with 0 (causal attention)
            global_attn_mask = global_attn_mask.triu(diagonal=1)

            # Enable bidirectional attention between image tokens
            img_mask = torch.zeros_like(global_attn_mask)
            img_mask[:, :, :, img_pos] += 1
            img_mask[:, :, img_pos, :] += 1
            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
            global_attn_masks.append(global_attn_mask)

            # GGUF compatibility: config might be Gemma3TextConfig directly
            text_config = getattr(self.config, "text_config", self.config)
            sliding_window = text_config.sliding_window
            if sliding_window is not None:
                # Create a local causal mask with sliding window (1024)
                local_attn_mask = torch.ones_like(global_attn_mask)
                local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
                local_attn_mask = torch.where(
                    local_attn_mask == 0, global_attn_mask, float("-inf")
                )
                local_attn_masks.append(local_attn_mask)

        return {
            "has_images": True,
            "seq_lens": seq_lens,
            "global_attn_masks": global_attn_masks,
            "local_attn_masks": local_attn_masks,
        }

    def prepare_attn_masks(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        mask_dtype: torch.dtype,
        **kwargs,
    ):
        kwargs["has_images"] = True
        # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
        # This is a HACK. Fix this.
        start_indices = (positions == 0).cpu().nonzero()
        num_seqs = len(start_indices)
        seq_lens = []
        for i in range(num_seqs):
            start_idx = start_indices[i].item()
            if i < num_seqs - 1:
                end_idx = start_indices[i + 1].item()
            else:
                end_idx = len(input_ids)
            seq_lens.append(end_idx - start_idx)
        kwargs["seq_lens"] = seq_lens

        global_attn_masks = []
        local_attn_masks = []
        start_idx = 0
        for seq_len in seq_lens:
            end_idx = start_idx + seq_len
            input_token_ids = input_ids[start_idx:end_idx]
            start_idx = end_idx
            # Create a global causal mask.
            global_attn_mask = torch.empty(
                1,
                1,
                seq_len,
                seq_len,
                dtype=mask_dtype,
                device=input_ids.device,
            )
            global_attn_mask.fill_(float("-inf"))
            # Fill the lower triangle with 0.
            global_attn_mask = global_attn_mask.triu(diagonal=1)

            # Consider the bidirectional attention between image tokens.
            img_mask = torch.zeros_like(global_attn_mask)
            img_pos = input_token_ids == self.config.image_token_index
            img_mask[:, :, :, img_pos] += 1
            img_mask[:, :, img_pos, :] += 1
            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
            global_attn_masks.append(global_attn_mask)

            sliding_window = self.config.text_config.sliding_window
            if sliding_window is not None:
                # Create a local causal mask with sliding window (1024).
                local_attn_mask = torch.ones_like(global_attn_mask)
                local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
                local_attn_mask = torch.where(
                    local_attn_mask == 0, global_attn_mask, float("-inf")
                )
                local_attn_masks.append(local_attn_mask)
        kwargs["global_attn_masks"] = global_attn_masks
        kwargs["local_attn_masks"] = local_attn_masks
        return kwargs

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
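Both removed methods build the masks the same way: start from a causal mask filled with -inf above the diagonal, open up bidirectional attention between image tokens, and, when the text config defines a sliding window, derive a local variant that additionally masks out everything farther back than the window. For reference, here is a minimal standalone sketch of that scheme outside of vLLM (it is not part of this commit's diff); the token ids, image token id, and window size below are illustrative only.

import torch

# Toy inputs (illustrative only): one sequence of 8 tokens in which ids equal
# to IMAGE_TOKEN_ID stand in for image placeholder tokens.
IMAGE_TOKEN_ID = 262144
input_ids = torch.tensor(
    [2, 10, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 11, 12, 13]
)
positions = torch.arange(len(input_ids))
sliding_window = 4
mask_dtype = torch.float32

# Split the flattened batch into sequences wherever the position id resets to 0
# (the same heuristic the removed code uses).
start_indices = (positions == 0).nonzero().flatten().tolist()
seq_lens = [
    (start_indices[i + 1] if i + 1 < len(start_indices) else len(input_ids))
    - start_indices[i]
    for i in range(len(start_indices))
]

global_attn_masks, local_attn_masks = [], []
start = 0
for seq_len in seq_lens:
    token_ids = input_ids[start : start + seq_len]
    start += seq_len

    # Global mask: 0 where attention is allowed, -inf where it is not.
    # triu(diagonal=1) keeps -inf strictly above the diagonal (future tokens)
    # and zeroes the lower triangle, i.e. plain causal attention.
    global_mask = torch.full((1, 1, seq_len, seq_len), float("-inf"), dtype=mask_dtype)
    global_mask = global_mask.triu(diagonal=1)

    # Image tokens attend to each other bidirectionally: mark the rows and the
    # columns of image positions; entries that reach 2 are image-to-image pairs.
    img_pos = token_ids == IMAGE_TOKEN_ID
    img_mask = torch.zeros_like(global_mask)
    img_mask[:, :, :, img_pos] += 1
    img_mask[:, :, img_pos, :] += 1
    global_mask = torch.where(img_mask == 2, 0.0, global_mask)
    global_attn_masks.append(global_mask)

    # Local (sliding-window) variant: on top of the global mask, positions that
    # lie sliding_window or more tokens in the past are masked out as well.
    outside_window = torch.tril(torch.ones_like(global_mask), diagonal=-sliding_window)
    local_mask = torch.where(outside_window == 0, global_mask, float("-inf"))
    local_attn_masks.append(local_mask)

print(global_attn_masks[0][0, 0])
print(local_attn_masks[0][0, 0])

The img_mask == 2 test works because an entry gets +1 for sitting in an image-token row and +1 for sitting in an image-token column, so only image-to-image pairs reach 2; text/image pairs stay at 1 and keep their causal value from the global mask.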
@@ -520,17 +520,6 @@ def is_interleaved(config: PretrainedConfig) -> bool:
    return False


def uses_custom_attention_masks(config: PretrainedConfig) -> bool:
    """Detect if model uses custom attention mask generation for multimodal.

    Some multimodal models require custom attention masks that enable
    bidirectional attention between image tokens while maintaining causal
    attention for text tokens. Currently applies to Gemma3 multimodal models.
    """
    architectures = getattr(config, "architectures", [])
    return "Gemma3ForConditionalGeneration" in architectures


def _maybe_update_auto_config_kwargs(kwargs: dict[str, Any], model_type: str):
    """
    Update kwargs for AutoConfig initialization based on model_type
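The removed helper is a plain architecture-name check against the Hugging Face config. A self-contained sketch of the same check follows; it uses duck-typed config objects built with SimpleNamespace instead of a real PretrainedConfig, purely to keep the example dependency-free.

from types import SimpleNamespace

def uses_custom_attention_masks(config) -> bool:
    # Same logic as the removed helper: look for the Gemma3 multimodal
    # architecture in the config's `architectures` list.
    architectures = getattr(config, "architectures", [])
    return "Gemma3ForConditionalGeneration" in architectures

gemma3_cfg = SimpleNamespace(architectures=["Gemma3ForConditionalGeneration"])
llama_cfg = SimpleNamespace(architectures=["LlamaForCausalLM"])
bare_cfg = SimpleNamespace()  # no `architectures` attribute at all

assert uses_custom_attention_masks(gemma3_cfg) is True
assert uses_custom_attention_masks(llama_cfg) is False
assert uses_custom_attention_masks(bare_cfg) is False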
@@ -324,7 +324,6 @@ class GPUModelRunner(
        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.uses_mrope = model_config.uses_mrope
        self.uses_custom_attention_masks = model_config.uses_custom_attention_masks
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            model_config
        )
@@ -2352,24 +2351,6 @@ class GPUModelRunner(
                **self._init_model_kwargs(num_scheduled_tokens),
                **self._extract_mm_kwargs(scheduler_output),
            }

            # Generate custom attention masks for models that require them.
            # V1 pre-generates embeddings, so forward() skips prepare_attn_masks().
            # Check mm_features (mm_embeds is empty during decode).
            has_mm_features = any(
                req_state.mm_features for req_state in self.requests.values()
            )
            if (
                self.uses_custom_attention_masks
                and has_mm_features
                and hasattr(self.model, "generate_attention_masks")
            ):
                mask_kwargs = self.model.generate_attention_masks(
                    self.input_ids.gpu[:num_scheduled_tokens],
                    self.positions.gpu[:num_scheduled_tokens],
                    mask_dtype=self.model.dtype,
                )
                model_kwargs.update(mask_kwargs)
        elif self.enable_prompt_embeds and is_first_rank:
            # Get the input embeddings for the tokens that are not input embeds,
            # then put them into the appropriate positions.
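On the runner side, the deleted block was an opt-in hook: mask generation only runs when the model config flags the architecture, the batch actually carries multimodal features, and the model exposes generate_attention_masks; whatever dict the hook returns is merged into the forward kwargs. Below is a minimal sketch of that dispatch pattern with stand-in classes (ToyMaskedModel, ToyRunner, and build_model_kwargs are illustrative names, not vLLM APIs).

from typing import Any

import torch


class ToyMaskedModel:
    """Stand-in for a model that opts in to custom attention masks."""

    dtype = torch.float32

    def generate_attention_masks(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        mask_dtype: torch.dtype,
    ) -> dict[str, Any]:
        seq_len = input_ids.shape[0]
        causal = torch.full((1, 1, seq_len, seq_len), float("-inf"), dtype=mask_dtype)
        return {"global_attn_masks": [causal.triu(diagonal=1)]}


class ToyRunner:
    """Stand-in for the runner-side dispatch removed by this commit."""

    def __init__(self, model, uses_custom_attention_masks: bool):
        self.model = model
        self.uses_custom_attention_masks = uses_custom_attention_masks

    def build_model_kwargs(
        self, input_ids: torch.Tensor, positions: torch.Tensor, has_mm_features: bool
    ) -> dict[str, Any]:
        model_kwargs: dict[str, Any] = {}
        # Only call the hook when (a) the config flags the architecture,
        # (b) the batch actually contains multimodal features, and
        # (c) the model implements the hook.
        if (
            self.uses_custom_attention_masks
            and has_mm_features
            and hasattr(self.model, "generate_attention_masks")
        ):
            mask_kwargs = self.model.generate_attention_masks(
                input_ids, positions, mask_dtype=self.model.dtype
            )
            model_kwargs.update(mask_kwargs)
        return model_kwargs


runner = ToyRunner(ToyMaskedModel(), uses_custom_attention_masks=True)
kwargs = runner.build_model_kwargs(
    input_ids=torch.arange(6), positions=torch.arange(6), has_mm_features=True
)
print(sorted(kwargs))  # ['global_attn_masks']

Routing the result through a plain dict keeps the runner agnostic about which masks a model needs; the model decides which keys to hand back to its own forward().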