[bugfix] fix siglip batch text output error (#28365)

Signed-off-by: piood <2477084691@qq.com>
Yu Jiaqi 2025-11-10 21:21:15 +08:00 committed by GitHub
parent 6f7de33bed
commit 15be507c86


@@ -19,6 +19,7 @@ from transformers import (
 )
 from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -379,6 +380,7 @@ class SiglipAttention(nn.Module):
         quant_config: QuantizationConfig | None = None,
         *,
         prefix: str = "",
+        attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
     ) -> None:
         super().__init__()
@@ -413,8 +415,11 @@ class SiglipAttention(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
-        self.attn = MultiHeadAttention(
-            self.num_heads_per_partition, self.head_dim, self.scale
+        self.attn = attn_cls(
+            self.num_heads_per_partition,
+            self.head_dim,
+            self.scale,
+            prefix=f"{prefix}.attn",
         )

     def forward(
@@ -424,25 +429,7 @@ class SiglipAttention(nn.Module):
         """Input shape: Batch x Time x Channel"""
         qkv_states, _ = self.qkv_proj(hidden_states)
         query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
-        needs_unsqueeze = query_states.ndim == 2
-        if needs_unsqueeze:
-            query_states, key_states, value_states = (
-                query_states.unsqueeze(0),
-                key_states.unsqueeze(0),
-                value_states.unsqueeze(0),
-            )
         out = self.attn(query_states, key_states, value_states)
-        if needs_unsqueeze:
-            out, query_states, key_states, value_states = (
-                out.squeeze(0),
-                query_states.squeeze(0),
-                key_states.squeeze(0),
-                value_states.squeeze(0),
-            )

         attn_output, _ = self.out_proj(out)
         return attn_output, None
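
For context (plain PyTorch, not vLLM code; the hidden size and sequence lengths are made up), here is a self-contained sketch of the failure mode the deleted workaround caused: unsqueezing a flattened `[num_tokens, hidden]` batch into a single fake `[1, num_tokens, hidden]` sequence lets tokens from different texts attend to each other.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden = 8
seq_lens = [3, 2]  # two texts, flattened to 5 tokens as vLLM does
q = torch.randn(sum(seq_lens), hidden)
k = torch.randn(sum(seq_lens), hidden)
v = torch.randn(sum(seq_lens), hidden)

# Old path: treat all 5 tokens as one sequence -> cross-sequence attention.
joint = F.scaled_dot_product_attention(
    q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
).squeeze(0)

# Correct path: attend within each sequence independently.
parts, start = [], 0
for n in seq_lens:
    sl = slice(start, start + n)
    parts.append(
        F.scaled_dot_product_attention(
            q[sl].unsqueeze(0), k[sl].unsqueeze(0), v[sl].unsqueeze(0)
        ).squeeze(0)
    )
    start += n
per_seq = torch.cat(parts)

print(torch.allclose(joint, per_seq))  # False: batched texts contaminate each other
```

Routing the text tower through EncoderOnlyAttention instead lets vLLM's attention backend handle the flattened layout with per-sequence boundaries, so the fake batch dimension is no longer needed.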
@@ -495,6 +482,7 @@ class SiglipEncoderLayer(nn.Module):
         quant_config: QuantizationConfig | None = None,
         *,
         prefix: str = "",
+        attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
     ) -> None:
         super().__init__()
@@ -504,6 +492,7 @@ class SiglipEncoderLayer(nn.Module):
             config,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
+            attn_cls=attn_cls,
         )
         self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(
@@ -539,6 +528,7 @@ class SiglipEncoder(nn.Module):
         num_hidden_layers_override: int | None = None,
         *,
         prefix: str = "",
+        attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
     ) -> None:
         super().__init__()
@@ -555,6 +545,7 @@ class SiglipEncoder(nn.Module):
                 config,
                 quant_config=quant_config,
                 prefix=f"{prefix}.layers.{layer_idx}",
+                attn_cls=attn_cls,
            )
            for layer_idx in range(num_hidden_layers)
        ]
@@ -598,6 +589,7 @@ class SiglipTextTransformer(nn.Module):
             config=config,
             quant_config=quant_config,
             prefix=f"{prefix}.encoder",
+            attn_cls=EncoderOnlyAttention,
         )

         self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
@@ -709,6 +701,7 @@ class SiglipVisionTransformer(nn.Module):
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
             prefix=f"{prefix}.encoder",
+            attn_cls=MultiHeadAttention,
         )

         num_hidden_layers = config.num_hidden_layers
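
The hunks above thread the attention implementation through as a constructor argument rather than hard-coding it, so the text tower can pick EncoderOnlyAttention while the vision tower keeps MultiHeadAttention without duplicating the encoder stack. A minimal sketch of the pattern (stand-in classes, not the real vLLM types):

```python
class Attn:  # stand-in for MultiHeadAttention
    def __init__(self, prefix: str = "") -> None:
        self.prefix = prefix

class EncAttn(Attn):  # stand-in for EncoderOnlyAttention
    pass

class Encoder:
    def __init__(self, *, prefix: str, attn_cls: type[Attn]) -> None:
        # The caller decides which attention backend each layer gets.
        self.attn = attn_cls(prefix=f"{prefix}.attn")

text_encoder = Encoder(prefix="text_model.encoder", attn_cls=EncAttn)
vision_encoder = Encoder(prefix="vision_model.encoder", attn_cls=Attn)
```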
@@ -1034,10 +1027,56 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
             inputs_embeds=inputs_embeds,
         )
         text_features = self.text_model.head(last_hidden_state)
-        # Flip to extract CLS token (first token after reversal) for pooling
-        text_features = text_features.flip(0)
+        # SigLIP uses reversed position_ids;
+        # flip sequences to move EOS token to first position
+        text_features = self._flip_sequences_by_position_ids(
+            text_features, position_ids
+        )
         return text_features

+    def _flip_sequences_by_position_ids(
+        self,
+        features: torch.Tensor,
+        position_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        """Flip sequences so EOS token moves to first position for CLS pooling.
+
+        SigLIP position_ids are reversed within each sequence. This method
+        detects sequence boundaries and flips each sequence individually.
+        """
+        if len(features) == 1:
+            return features
+
+        # Detect sequence boundaries where position_ids decrease
+        position_diffs = position_ids[1:] - position_ids[:-1]
+        boundary_mask = position_diffs <= 0
+        boundary_indices = torch.cat(
+            [
+                torch.tensor([0], device=features.device),
+                torch.where(boundary_mask)[0] + 1,
+                torch.tensor([len(features)], device=features.device),
+            ]
+        )
+
+        # For each sequence [start, end), position i flips to: start + end - 1 - i
+        lengths = boundary_indices[1:] - boundary_indices[:-1]
+        starts = boundary_indices[:-1]
+        ends = boundary_indices[1:]
+
+        # Assign sequence ID to each element
+        sequence_ids = torch.arange(
+            len(lengths), device=features.device
+        ).repeat_interleave(lengths)
+
+        # Calculate flipped indices for all positions at once
+        current_positions = torch.arange(len(features), device=features.device)
+        flip_indices = starts[sequence_ids] + ends[sequence_ids] - 1 - current_positions
+
+        return features[flip_indices]
+
     def get_image_features(
         self,
         pixel_values: torch.Tensor,
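
A standalone walk-through of the flip-index arithmetic in `_flip_sequences_by_position_ids` (plain PyTorch; the example assumes position_ids that increase within each sequence and reset at each boundary, which is the layout the `position_diffs <= 0` test detects):

```python
import torch

position_ids = torch.tensor([0, 1, 2, 0, 1])  # two sequences: len 3 and len 2
features = torch.arange(5)                    # token indices standing in for features

diffs = position_ids[1:] - position_ids[:-1]
boundaries = torch.cat([
    torch.tensor([0]),
    torch.where(diffs <= 0)[0] + 1,
    torch.tensor([len(features)]),
])                                            # tensor([0, 3, 5])
lengths = boundaries[1:] - boundaries[:-1]    # tensor([3, 2])
starts, ends = boundaries[:-1], boundaries[1:]
seq_ids = torch.arange(len(lengths)).repeat_interleave(lengths)  # [0, 0, 0, 1, 1]
pos = torch.arange(len(features))
flip = starts[seq_ids] + ends[seq_ids] - 1 - pos

print(flip)            # tensor([2, 1, 0, 4, 3])
print(features[flip])  # each sequence reversed in place: [2, 1, 0 | 4, 3]
```

Unlike the old `text_features.flip(0)`, which reversed the whole flattened batch as if it were one sequence, this flips each sequence independently, so CLS pooling reads the correct (last) token of every batched text.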