[Bugfix] Revert custom attention mask for gemma3-mm (#28995)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
commit 64192d5624
parent fe25772aa9
@@ -32,7 +32,6 @@ from vllm.transformers_utils.config import (
     try_get_generation_config,
     try_get_safetensors_metadata,
     try_get_tokenizer_config,
-    uses_custom_attention_masks,
     uses_mrope,
 )
 from vllm.transformers_utils.gguf_utils import (
@@ -1625,10 +1624,6 @@ class ModelConfig:
     def uses_mrope(self) -> bool:
         return uses_mrope(self.hf_config)

-    @property
-    def uses_custom_attention_masks(self) -> bool:
-        return uses_custom_attention_masks(self.hf_config)
-
     @property
     def is_multimodal_model(self) -> bool:
         return self.multimodal_config is not None
@@ -596,7 +596,7 @@ class Gemma3ForConditionalGeneration(
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []
@@ -644,142 +644,6 @@ class Gemma3ForConditionalGeneration(

         return hidden_states

-    def generate_attention_masks(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        mask_dtype: torch.dtype,
-    ) -> dict[str, Any]:
-        """Generate custom attention masks for Gemma3 multimodal inputs.
-
-        This is called by V1 engine's gpu_model_runner during preprocessing
-        to generate attention masks that allow bidirectional attention between
-        image tokens while maintaining causal attention for text.
-        """
-        # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
-        # This is a HACK. Fix this.
-        start_indices = (positions == 0).cpu().nonzero()
-        num_seqs = len(start_indices)
-        seq_lens = []
-        for i in range(num_seqs):
-            start_idx = start_indices[i]
-            end_idx = start_indices[i + 1] if i < num_seqs - 1 else len(input_ids)
-            seq_lens.append(end_idx - start_idx)
-
-        global_attn_masks = []
-        local_attn_masks = []
-        start_idx = 0
-        for seq_idx, seq_len in enumerate(seq_lens):
-            end_idx = start_idx + seq_len
-            input_token_ids = input_ids[start_idx:end_idx]
-
-            # Find image token positions
-            img_pos = input_token_ids == self.config.image_token_index
-
-            start_idx = end_idx
-
-            # Create a global causal mask
-            global_attn_mask = torch.empty(
-                1,
-                1,
-                seq_len,
-                seq_len,
-                dtype=mask_dtype,
-                device=input_ids.device,
-            )
-            global_attn_mask.fill_(float("-inf"))
-            # Fill the lower triangle with 0 (causal attention)
-            global_attn_mask = global_attn_mask.triu(diagonal=1)
-
-            # Enable bidirectional attention between image tokens
-            img_mask = torch.zeros_like(global_attn_mask)
-            img_mask[:, :, :, img_pos] += 1
-            img_mask[:, :, img_pos, :] += 1
-            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
-            global_attn_masks.append(global_attn_mask)
-
-            # GGUF compatibility: config might be Gemma3TextConfig directly
-            text_config = getattr(self.config, "text_config", self.config)
-            sliding_window = text_config.sliding_window
-            if sliding_window is not None:
-                # Create a local causal mask with sliding window (1024)
-                local_attn_mask = torch.ones_like(global_attn_mask)
-                local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
-                local_attn_mask = torch.where(
-                    local_attn_mask == 0, global_attn_mask, float("-inf")
-                )
-                local_attn_masks.append(local_attn_mask)
-
-        return {
-            "has_images": True,
-            "seq_lens": seq_lens,
-            "global_attn_masks": global_attn_masks,
-            "local_attn_masks": local_attn_masks,
-        }
-
-    def prepare_attn_masks(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        mask_dtype: torch.dtype,
-        **kwargs,
-    ):
-        kwargs["has_images"] = True
-        # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
-        # This is a HACK. Fix this.
-        start_indices = (positions == 0).cpu().nonzero()
-        num_seqs = len(start_indices)
-        seq_lens = []
-        for i in range(num_seqs):
-            start_idx = start_indices[i].item()
-            if i < num_seqs - 1:
-                end_idx = start_indices[i + 1].item()
-            else:
-                end_idx = len(input_ids)
-            seq_lens.append(end_idx - start_idx)
-        kwargs["seq_lens"] = seq_lens
-
-        global_attn_masks = []
-        local_attn_masks = []
-        start_idx = 0
-        for seq_len in seq_lens:
-            end_idx = start_idx + seq_len
-            input_token_ids = input_ids[start_idx:end_idx]
-            start_idx = end_idx
-            # Create a global causal mask.
-            global_attn_mask = torch.empty(
-                1,
-                1,
-                seq_len,
-                seq_len,
-                dtype=mask_dtype,
-                device=input_ids.device,
-            )
-            global_attn_mask.fill_(float("-inf"))
-            # Fill the lower triangle with 0.
-            global_attn_mask = global_attn_mask.triu(diagonal=1)
-
-            # Consider the bidirectional attention between image tokens.
-            img_mask = torch.zeros_like(global_attn_mask)
-            img_pos = input_token_ids == self.config.image_token_index
-            img_mask[:, :, :, img_pos] += 1
-            img_mask[:, :, img_pos, :] += 1
-            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
-            global_attn_masks.append(global_attn_mask)
-
-            sliding_window = self.config.text_config.sliding_window
-            if sliding_window is not None:
-                # Create a local causal mask with sliding window (1024).
-                local_attn_mask = torch.ones_like(global_attn_mask)
-                local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
-                local_attn_mask = torch.where(
-                    local_attn_mask == 0, global_attn_mask, float("-inf")
-                )
-                local_attn_masks.append(local_attn_mask)
-        kwargs["global_attn_masks"] = global_attn_masks
-        kwargs["local_attn_masks"] = local_attn_masks
-        return kwargs
-
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
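For reference, a minimal standalone sketch of the masking scheme the removed helpers implement: a causal global mask in which image tokens may additionally attend to each other bidirectionally, plus a sliding-window local mask derived from it. The 6-token sequence, the image token id 999, and the window of 2 are invented for the illustration; only torch is assumed, and none of this is vLLM code.

import torch

IMAGE_TOKEN_ID = 999  # placeholder id, not Gemma3's real image token id
input_token_ids = torch.tensor([1, 5, 999, 999, 999, 7])
seq_len = input_token_ids.numel()
sliding_window = 2  # toy window; Gemma3 takes the value from its text config

# Global mask: -inf above the diagonal (causal), 0 on and below it.
global_attn_mask = torch.full((1, 1, seq_len, seq_len), float("-inf"))
global_attn_mask = global_attn_mask.triu(diagonal=1)

# Unmask pairs where both query and key are image tokens (bidirectional).
img_pos = input_token_ids == IMAGE_TOKEN_ID
img_mask = torch.zeros_like(global_attn_mask)
img_mask[:, :, :, img_pos] += 1
img_mask[:, :, img_pos, :] += 1
global_attn_mask = torch.where(img_mask == 2, 0.0, global_attn_mask)

# Local mask: keep the global mask for keys inside the window, -inf for
# keys farther back than the window.
local_attn_mask = torch.ones_like(global_attn_mask)
local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
local_attn_mask = torch.where(local_attn_mask == 0, global_attn_mask, float("-inf"))

print(global_attn_mask[0, 0])
print(local_attn_mask[0, 0])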
@@ -520,17 +520,6 @@ def is_interleaved(config: PretrainedConfig) -> bool:
     return False


-def uses_custom_attention_masks(config: PretrainedConfig) -> bool:
-    """Detect if model uses custom attention mask generation for multimodal.
-
-    Some multimodal models require custom attention masks that enable
-    bidirectional attention between image tokens while maintaining causal
-    attention for text tokens. Currently applies to Gemma3 multimodal models.
-    """
-    architectures = getattr(config, "architectures", [])
-    return "Gemma3ForConditionalGeneration" in architectures
-
-
 def _maybe_update_auto_config_kwargs(kwargs: dict[str, Any], model_type: str):
     """
     Update kwargs for AutoConfig initialization based on model_type
@@ -324,7 +324,6 @@ class GPUModelRunner(
         # Multi-modal data support
         self.mm_registry = MULTIMODAL_REGISTRY
         self.uses_mrope = model_config.uses_mrope
-        self.uses_custom_attention_masks = model_config.uses_custom_attention_masks
         self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
             model_config
         )
@@ -2352,24 +2351,6 @@ class GPUModelRunner(
                 **self._init_model_kwargs(num_scheduled_tokens),
                 **self._extract_mm_kwargs(scheduler_output),
             }
-
-            # Generate custom attention masks for models that require them.
-            # V1 pre-generates embeddings, so forward() skips prepare_attn_masks().
-            # Check mm_features (mm_embeds is empty during decode).
-            has_mm_features = any(
-                req_state.mm_features for req_state in self.requests.values()
-            )
-            if (
-                self.uses_custom_attention_masks
-                and has_mm_features
-                and hasattr(self.model, "generate_attention_masks")
-            ):
-                mask_kwargs = self.model.generate_attention_masks(
-                    self.input_ids.gpu[:num_scheduled_tokens],
-                    self.positions.gpu[:num_scheduled_tokens],
-                    mask_dtype=self.model.dtype,
-                )
-                model_kwargs.update(mask_kwargs)
         elif self.enable_prompt_embeds and is_first_rank:
             # Get the input embeddings for the tokens that are not input embeds,
             # then put them into the appropriate positions.
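As a sketch of the hook pattern this removed runner code used (probe the loaded model for a mask-generation method and merge its output into the kwargs passed to forward()), assuming toy stand-ins rather than real vLLM classes:

import torch


class ToyModel:
    # Stand-in for a loaded model; the dtype attribute and the hook name
    # mirror the removed code, but nothing here is a vLLM class.
    dtype = torch.float32

    def generate_attention_masks(self, input_ids, positions, mask_dtype):
        # Return extra forward() kwargs; positions are unused in this toy.
        seq_len = input_ids.numel()
        mask = torch.full((1, 1, seq_len, seq_len), float("-inf"), dtype=mask_dtype)
        return {"global_attn_masks": [mask.triu(diagonal=1)]}


def build_model_kwargs(model, input_ids, positions, num_scheduled_tokens):
    # Only models that implement the hook contribute extra kwargs.
    model_kwargs = {}
    if hasattr(model, "generate_attention_masks"):
        model_kwargs.update(
            model.generate_attention_masks(
                input_ids[:num_scheduled_tokens],
                positions[:num_scheduled_tokens],
                mask_dtype=model.dtype,
            )
        )
    return model_kwargs


print(build_model_kwargs(ToyModel(), torch.arange(6), torch.arange(6), 4).keys())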