mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-30 11:07:24 +08:00
[bugfix] fix siglip batch text output error (#28365)
Signed-off-by: piood <2477084691@qq.com>
This commit is contained in:
parent
6f7de33bed
commit
15be507c86
@ -19,6 +19,7 @@ from transformers import (
|
||||
)
|
||||
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
@ -379,6 +380,7 @@ class SiglipAttention(nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@ -413,8 +415,11 @@ class SiglipAttention(nn.Module):
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||
|
||||
self.attn = MultiHeadAttention(
|
||||
self.num_heads_per_partition, self.head_dim, self.scale
|
||||
self.attn = attn_cls(
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim,
|
||||
self.scale,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
def forward(
|
||||
@ -424,25 +429,7 @@ class SiglipAttention(nn.Module):
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
qkv_states, _ = self.qkv_proj(hidden_states)
|
||||
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
|
||||
|
||||
needs_unsqueeze = query_states.ndim == 2
|
||||
if needs_unsqueeze:
|
||||
query_states, key_states, value_states = (
|
||||
query_states.unsqueeze(0),
|
||||
key_states.unsqueeze(0),
|
||||
value_states.unsqueeze(0),
|
||||
)
|
||||
|
||||
out = self.attn(query_states, key_states, value_states)
|
||||
|
||||
if needs_unsqueeze:
|
||||
out, query_states, key_states, value_states = (
|
||||
out.squeeze(0),
|
||||
query_states.squeeze(0),
|
||||
key_states.squeeze(0),
|
||||
value_states.squeeze(0),
|
||||
)
|
||||
|
||||
attn_output, _ = self.out_proj(out)
|
||||
|
||||
return attn_output, None
|
||||
@ -495,6 +482,7 @@ class SiglipEncoderLayer(nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@ -504,6 +492,7 @@ class SiglipEncoderLayer(nn.Module):
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
attn_cls=attn_cls,
|
||||
)
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
self.mlp = SiglipMLP(
|
||||
@ -539,6 +528,7 @@ class SiglipEncoder(nn.Module):
|
||||
num_hidden_layers_override: int | None = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@ -555,6 +545,7 @@ class SiglipEncoder(nn.Module):
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.layers.{layer_idx}",
|
||||
attn_cls=attn_cls,
|
||||
)
|
||||
for layer_idx in range(num_hidden_layers)
|
||||
]
|
||||
@ -598,6 +589,7 @@ class SiglipTextTransformer(nn.Module):
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.encoder",
|
||||
attn_cls=EncoderOnlyAttention,
|
||||
)
|
||||
|
||||
self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
@ -709,6 +701,7 @@ class SiglipVisionTransformer(nn.Module):
|
||||
quant_config=quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers_override,
|
||||
prefix=f"{prefix}.encoder",
|
||||
attn_cls=MultiHeadAttention,
|
||||
)
|
||||
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
@ -1034,10 +1027,56 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
text_features = self.text_model.head(last_hidden_state)
|
||||
# Flip to extract CLS token (first token after reversal) for pooling
|
||||
text_features = text_features.flip(0)
|
||||
|
||||
# SigLIP uses reversed position_ids;
|
||||
# flip sequences to move EOS token to first position
|
||||
text_features = self._flip_sequences_by_position_ids(
|
||||
text_features, position_ids
|
||||
)
|
||||
|
||||
return text_features
|
||||
|
||||
def _flip_sequences_by_position_ids(
|
||||
self,
|
||||
features: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Flip sequences so EOS token moves to first position for CLS pooling.
|
||||
|
||||
SigLIP position_ids are reversed within each sequence. This method detects
|
||||
sequence boundaries and flips each sequence individually.
|
||||
"""
|
||||
if len(features) == 1:
|
||||
return features
|
||||
|
||||
# Detect sequence boundaries where position_ids decrease
|
||||
position_diffs = position_ids[1:] - position_ids[:-1]
|
||||
boundary_mask = position_diffs <= 0
|
||||
|
||||
boundary_indices = torch.cat(
|
||||
[
|
||||
torch.tensor([0], device=features.device),
|
||||
torch.where(boundary_mask)[0] + 1,
|
||||
torch.tensor([len(features)], device=features.device),
|
||||
]
|
||||
)
|
||||
|
||||
# For each sequence [start, end), position i flips to: start + end - 1 - i
|
||||
lengths = boundary_indices[1:] - boundary_indices[:-1]
|
||||
starts = boundary_indices[:-1]
|
||||
ends = boundary_indices[1:]
|
||||
|
||||
# Assign sequence ID to each element
|
||||
sequence_ids = torch.arange(
|
||||
len(lengths), device=features.device
|
||||
).repeat_interleave(lengths)
|
||||
|
||||
# Calculate flipped indices for all positions at once
|
||||
current_positions = torch.arange(len(features), device=features.device)
|
||||
flip_indices = starts[sequence_ids] + ends[sequence_ids] - 1 - current_positions
|
||||
|
||||
return features[flip_indices]
|
||||
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user