diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index e363be523dcce..3cbdd64acc4a9 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -19,6 +19,7 @@ from transformers import ( ) from vllm.attention.layer import MultiHeadAttention +from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import divide, get_tensor_model_parallel_world_size @@ -379,6 +380,7 @@ class SiglipAttention(nn.Module): quant_config: QuantizationConfig | None = None, *, prefix: str = "", + attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention], ) -> None: super().__init__() @@ -413,8 +415,11 @@ class SiglipAttention(nn.Module): self.tp_size = get_tensor_model_parallel_world_size() self.num_heads_per_partition = divide(self.num_heads, self.tp_size) - self.attn = MultiHeadAttention( - self.num_heads_per_partition, self.head_dim, self.scale + self.attn = attn_cls( + self.num_heads_per_partition, + self.head_dim, + self.scale, + prefix=f"{prefix}.attn", ) def forward( @@ -424,25 +429,7 @@ class SiglipAttention(nn.Module): """Input shape: Batch x Time x Channel""" qkv_states, _ = self.qkv_proj(hidden_states) query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) - - needs_unsqueeze = query_states.ndim == 2 - if needs_unsqueeze: - query_states, key_states, value_states = ( - query_states.unsqueeze(0), - key_states.unsqueeze(0), - value_states.unsqueeze(0), - ) - out = self.attn(query_states, key_states, value_states) - - if needs_unsqueeze: - out, query_states, key_states, value_states = ( - out.squeeze(0), - query_states.squeeze(0), - key_states.squeeze(0), - value_states.squeeze(0), - ) - attn_output, _ = self.out_proj(out) return attn_output, None @@ -495,6 +482,7 @@ class SiglipEncoderLayer(nn.Module): quant_config: QuantizationConfig | None = None, *, prefix: str = "", + attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention], ) -> None: super().__init__() @@ -504,6 +492,7 @@ class SiglipEncoderLayer(nn.Module): config, quant_config=quant_config, prefix=f"{prefix}.self_attn", + attn_cls=attn_cls, ) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = SiglipMLP( @@ -539,6 +528,7 @@ class SiglipEncoder(nn.Module): num_hidden_layers_override: int | None = None, *, prefix: str = "", + attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention], ) -> None: super().__init__() @@ -555,6 +545,7 @@ class SiglipEncoder(nn.Module): config, quant_config=quant_config, prefix=f"{prefix}.layers.{layer_idx}", + attn_cls=attn_cls, ) for layer_idx in range(num_hidden_layers) ] @@ -598,6 +589,7 @@ class SiglipTextTransformer(nn.Module): config=config, quant_config=quant_config, prefix=f"{prefix}.encoder", + attn_cls=EncoderOnlyAttention, ) self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) @@ -709,6 +701,7 @@ class SiglipVisionTransformer(nn.Module): quant_config=quant_config, num_hidden_layers_override=num_hidden_layers_override, prefix=f"{prefix}.encoder", + attn_cls=MultiHeadAttention, ) num_hidden_layers = config.num_hidden_layers @@ -1034,10 +1027,56 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): inputs_embeds=inputs_embeds, ) text_features = self.text_model.head(last_hidden_state) - # Flip to extract CLS token (first token after reversal) for pooling - text_features = text_features.flip(0) + + # SigLIP uses reversed position_ids; + # flip sequences to move EOS token to first position + text_features = self._flip_sequences_by_position_ids( + text_features, position_ids + ) + return text_features + def _flip_sequences_by_position_ids( + self, + features: torch.Tensor, + position_ids: torch.Tensor, + ) -> torch.Tensor: + """Flip sequences so EOS token moves to first position for CLS pooling. + + SigLIP position_ids are reversed within each sequence. This method detects + sequence boundaries and flips each sequence individually. + """ + if len(features) == 1: + return features + + # Detect sequence boundaries where position_ids decrease + position_diffs = position_ids[1:] - position_ids[:-1] + boundary_mask = position_diffs <= 0 + + boundary_indices = torch.cat( + [ + torch.tensor([0], device=features.device), + torch.where(boundary_mask)[0] + 1, + torch.tensor([len(features)], device=features.device), + ] + ) + + # For each sequence [start, end), position i flips to: start + end - 1 - i + lengths = boundary_indices[1:] - boundary_indices[:-1] + starts = boundary_indices[:-1] + ends = boundary_indices[1:] + + # Assign sequence ID to each element + sequence_ids = torch.arange( + len(lengths), device=features.device + ).repeat_interleave(lengths) + + # Calculate flipped indices for all positions at once + current_positions = torch.arange(len(features), device=features.device) + flip_indices = starts[sequence_ids] + ends[sequence_ids] - 1 - current_positions + + return features[flip_indices] + def get_image_features( self, pixel_values: torch.Tensor,