diff --git a/vllm/model_executor/models/isaac.py b/vllm/model_executor/models/isaac.py
index 5c61e5bf48a70..d2d980a9aadf4 100644
--- a/vllm/model_executor/models/isaac.py
+++ b/vllm/model_executor/models/isaac.py
@@ -14,15 +14,30 @@
 import PIL.Image
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from einops import rearrange
 from transformers import PretrainedConfig, Qwen3Config
 from transformers.image_processing_utils import BatchFeature
 from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig
 from transformers.tokenization_utils import TensorType
 from typing_extensions import TypedDict, Unpack
-from vllm.attention.backends.registry import _Backend
+from vllm.attention.backends.registry import AttentionBackendEnum
+from vllm.attention.layer import (
+    check_upstream_fa_availability,
+    maybe_get_vit_flash_attn_backend,
+)
+from vllm.attention.ops.vit_attn_wrappers import (
+    vit_xformers_attn_wrapper,
+)
 from vllm.config import VllmConfig
-from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.distributed import parallel_state
+from vllm.distributed import utils as dist_utils
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
@@ -36,13 +51,14 @@ from vllm.model_executor.models.interfaces import (
 )
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
-from vllm.model_executor.models.siglip2navit import Siglip2Encoder
+from vllm.model_executor.models.siglip import SiglipMLP
 from vllm.model_executor.models.utils import (
     AutoWeightsLoader,
     WeightsMapper,
     _merge_multimodal_embeddings,
     maybe_prefix,
 )
+from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (
     MultiModalDataDict,
@@ -332,6 +348,16 @@ class PixelShuffleSiglip2VisionConfig(Siglip2VisionConfig):
         self.num_patches = num_patches
 
 
+def create_cumulative_seq_lengths(
+    seq_sizes: torch.Tensor, device: torch.device
+) -> tuple[torch.Tensor, int]:
+    """Create cumulative sequence lengths for variable-length attention."""
+    cu_seqlens = torch.zeros(len(seq_sizes) + 1, dtype=torch.int32, device=device)
+    cu_seqlens[1:] = seq_sizes.cumsum(0)
+    max_seqlen = int(seq_sizes.max().item()) if len(seq_sizes) > 0 else 0
+    return cu_seqlens, max_seqlen
+
+
 class Siglip2VariableSequenceEmbeddings(nn.Module):
     def __init__(self, config: PixelShuffleSiglip2VisionConfig):
         super().__init__()
@@ -367,7 +393,7 @@ class Siglip2VariableSequenceEmbeddings(nn.Module):
         align_corners = False
         antialias = True
         for spatial_shape in spatial_shapes:
-            height, width = spatial_shape
+            height, width = int(spatial_shape[0]), int(spatial_shape[1])
             # Guard to ensure height and width are positive for torch.compile
             if height > 0 and width > 0:
                 resized_pos_embed = F.interpolate(
@@ -399,21 +425,16 @@ class Siglip2VariableSequenceEmbeddings(nn.Module):
     ):
         seq_patches, _seq_sizes, _spatial_shapes = packed_seq_patches
 
-        # Apply patch embeddings
-        target_dtype = self.patch_embedding.weight.dtype
-        patch_embeds = self.patch_embedding(seq_patches.to(dtype=target_dtype))
+        target_weight = self.patch_embedding.weight
+        seq_patches = seq_patches.to(
+            device=target_weight.device, dtype=target_weight.dtype
+        )
+        patch_embeds = self.patch_embedding(seq_patches)
         pos_embeds = self.positional_embeddings(packed_seq_patches)
 
         # Flatten patch embeddings to match positional embeddings format
-        batch_size, patches_per_image, embed_dim = patch_embeds.shape
-
-        # For variable-length attention, we need to reshape to (total_tokens, embed_dim)
-        if batch_size != 1:
-            raise ValueError(
-                "Variable-length attention expects batch_size=1 for packed sequences"
-            )
-
-        patch_embeds = patch_embeds.view(batch_size * patches_per_image, embed_dim)
+        if patch_embeds.dim() == 3:
+            patch_embeds = patch_embeds.view(-1, patch_embeds.size(-1))
 
         # Add positional embeddings to patch embeddings
         embeddings = patch_embeds + pos_embeds
@@ -1162,6 +1183,313 @@ class IsaacMultiModalProcessor(BaseMultiModalProcessor):
         ]
 
 
+def all_gather_interleave(local_tensor: torch.Tensor, hidden_size: int, tp_size: int):
+    """All-gather the input tensor interleavely across model parallel group."""
+    import torch.distributed as dist
+
+    gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
+    dist.all_gather(
+        gathered_tensors, local_tensor, group=parallel_state.get_tp_group().device_group
+    )
+
+    gathered_tensors_split = [
+        torch.split(tensor, hidden_size // tp_size, -1) for tensor in gathered_tensors
+    ]
+    ordered_tensors = [
+        tensor for pair in zip(*gathered_tensors_split) for tensor in pair
+    ]
+    return torch.cat(ordered_tensors, dim=-1)
+
+
+class Siglip2VisionAttention(nn.Module):
+    def __init__(
+        self,
+        config: PixelShuffleSiglip2VisionConfig,
+        quant_config: QuantizationConfig | None = None,
+        *,
+        prefix: str = "",
+        use_data_parallel: bool = False,
+        use_upstream_fa: bool = False,
+        attn_backend: AttentionBackendEnum | None = None,
+        attn_backend_override: AttentionBackendEnum | None = None,
+    ) -> None:
+        super().__init__()
+
+        self.tp_size = (
+            1
+            if use_data_parallel
+            else parallel_state.get_tensor_model_parallel_world_size()
+        )
+        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
+        self.hidden_size_per_attention_head = dist_utils.divide(
+            config.hidden_size, config.num_attention_heads
+        )
+        self.num_attention_heads_per_partition = dist_utils.divide(
+            config.num_attention_heads, self.tp_size
+        )
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size=config.hidden_size,
+            head_size=self.hidden_size_per_attention_head,
+            total_num_heads=config.num_attention_heads,
+            total_num_kv_heads=config.num_attention_heads,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+            disable_tp=use_data_parallel,
+        )
+        self.out_proj = RowParallelLinear(
+            input_size=config.hidden_size,
+            output_size=config.hidden_size,
+            quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
+            disable_tp=use_data_parallel,
+        )
+
+        self.use_upstream_fa = use_upstream_fa
+        self.attn_backend = attn_backend
+
+        if self.attn_backend not in {
+            AttentionBackendEnum.FLASH_ATTN,
+            AttentionBackendEnum.ROCM_AITER_FA,
+        } and check_upstream_fa_availability(torch.get_default_dtype()):
+            self.attn_backend = AttentionBackendEnum.FLASH_ATTN
+            self.use_upstream_fa = True
+        if self.attn_backend not in {
+            AttentionBackendEnum.FLASH_ATTN,
+            AttentionBackendEnum.TORCH_SDPA,
+            AttentionBackendEnum.XFORMERS,
+            AttentionBackendEnum.ROCM_AITER_FA,
+        }:
+            raise RuntimeError(
+                f"Isaac vision embedding does not support {self.attn_backend} backend."
+            )
+        self.attn_backend, self.flash_attn_varlen_func = (
+            maybe_get_vit_flash_attn_backend(
+                self.attn_backend,
+                self.use_upstream_fa,
+                attn_backend_override=attn_backend_override,
+            )
+        )
+        self.is_flash_attn_backend = self.attn_backend in {
+            AttentionBackendEnum.FLASH_ATTN,
+            AttentionBackendEnum.ROCM_AITER_FA,
+        }
+
+    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
+        seq_len, bs, _ = qkv.shape
+        if self.tp_size > 1:
+            qkv = all_gather_interleave(qkv, self.qkv_proj.hidden_size, self.tp_size)
+
+        q, k, v = qkv.chunk(3, dim=2)
+
+        if self.tp_size > 1:
+            q = dist_utils.split_tensor_along_last_dim(q, self.tp_size)[self.tp_rank]
+            k = dist_utils.split_tensor_along_last_dim(k, self.tp_size)[self.tp_rank]
+            v = dist_utils.split_tensor_along_last_dim(v, self.tp_size)[self.tp_rank]
+
+        new_shape = (
+            seq_len,
+            bs,
+            self.num_attention_heads_per_partition,
+            self.hidden_size_per_attention_head,
+        )
+        q, k, v = (x.view(*new_shape) for x in (q, k, v))
+        return q, k, v
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *,
+        cu_seqlens: torch.Tensor,
+        max_seqlen: torch.Tensor | None,
+        seqlens: torch.Tensor | None,
+    ) -> torch.Tensor:
+        batch_size, _, _ = hidden_states.shape
+        if batch_size != 1:
+            raise ValueError("packed variable-length attention expects batch_size=1")
+        x = rearrange(hidden_states, "b s d -> s b d")
+        x, _ = self.qkv_proj(x)
+        q, k, v = self.split_qkv(x)
+        q, k, v = (rearrange(t, "s b h d -> b s h d") for t in (q, k, v))
+
+        if self.is_flash_attn_backend:
+            q, k, v = (rearrange(t, "b s ... -> (b s) ...") for t in (q, k, v))
+            output = self.flash_attn_varlen_func(
+                q,
+                k,
+                v,
+                cu_seqlens_q=cu_seqlens,
+                cu_seqlens_k=cu_seqlens,
+                max_seqlen_q=max_seqlen,
+                max_seqlen_k=max_seqlen,
+                dropout_p=0.0,
+                causal=False,
+            )
+            context_layer = rearrange(
+                output, "(b s) h d -> s b (h d)", b=batch_size
+            ).contiguous()
+        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
+            outputs = []
+            for i in range(1, len(cu_seqlens)):
+                start_idx = cu_seqlens[i - 1]
+                end_idx = cu_seqlens[i]
+                q_i = q[:, start_idx:end_idx]
+                k_i = k[:, start_idx:end_idx]
+                v_i = v[:, start_idx:end_idx]
+                q_i, k_i, v_i = (
+                    rearrange(tensor, "b s h d -> b h s d")
+                    for tensor in (q_i, k_i, v_i)
+                )
+                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
+                output_i = rearrange(output_i, "b h s d -> b s h d")
+                outputs.append(output_i)
+            context_layer = torch.cat(outputs, dim=1)
+            context_layer = rearrange(
+                context_layer, "b s h d -> s b (h d)"
+            ).contiguous()
+        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
+            if seqlens is None:
+                raise ValueError("xFormers attention backend requires seqlens tensor.")
+            context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
+        else:
+            raise RuntimeError(
+                f"Isaac vision embedding does not support {self.attn_backend} backend."
+            )
+
+        output, _ = self.out_proj(context_layer)
+        output = rearrange(output, "s b d -> b s d")
+        return output
+
+
+class Siglip2EncoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: PixelShuffleSiglip2VisionConfig,
+        quant_config: QuantizationConfig | None = None,
+        *,
+        prefix: str = "",
+        attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
+        attn_backend_override: AttentionBackendEnum | None = None,
+        use_upstream_fa: bool = False,
+        use_data_parallel: bool = False,
+    ) -> None:
+        super().__init__()
+        self.embed_dim = config.hidden_size
+        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.self_attn = Siglip2VisionAttention(
+            config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
+            use_data_parallel=use_data_parallel,
+            use_upstream_fa=use_upstream_fa,
+            attn_backend=attn_backend,
+            attn_backend_override=attn_backend_override,
+        )
+        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.mlp = SiglipMLP(
+            config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *,
+        cu_seqlens: torch.Tensor,
+        max_seqlen: torch.Tensor | None,
+        seqlens: torch.Tensor | None,
+    ) -> torch.Tensor:
+        residual = hidden_states
+
+        hidden_states = self.layer_norm1(hidden_states)
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+            seqlens=seqlens,
+        )
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.layer_norm2(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
+
+class Siglip2Encoder(nn.Module):
+    def __init__(
+        self,
+        config: PixelShuffleSiglip2VisionConfig,
+        quant_config: QuantizationConfig | None = None,
+        *,
+        prefix: str = "",
+        use_data_parallel: bool = False,
+        attn_backend_override: AttentionBackendEnum | None = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        embed_dim = config.hidden_size
+        num_heads = config.num_attention_heads
+        head_dim = embed_dim // num_heads
+        self.attn_backend = get_vit_attn_backend(
+            head_size=head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
+        )
+        self.use_upstream_fa = False
+        if self.attn_backend not in {
+            AttentionBackendEnum.FLASH_ATTN,
+            AttentionBackendEnum.ROCM_AITER_FA,
+        } and check_upstream_fa_availability(torch.get_default_dtype()):
+            self.attn_backend = AttentionBackendEnum.FLASH_ATTN
+            self.use_upstream_fa = True
+        if self.attn_backend not in {
+            AttentionBackendEnum.FLASH_ATTN,
+            AttentionBackendEnum.TORCH_SDPA,
+            AttentionBackendEnum.XFORMERS,
+            AttentionBackendEnum.ROCM_AITER_FA,
+        }:
+            raise RuntimeError(
+                f"Isaac vision embedding does not support {self.attn_backend} backend."
+            )
+        self.layers = nn.ModuleList(
+            [
+                Siglip2EncoderLayer(
+                    config,
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.layers.{layer_idx}",
+                    attn_backend=self.attn_backend,
+                    attn_backend_override=attn_backend_override,
+                    use_upstream_fa=self.use_upstream_fa,
+                    use_data_parallel=use_data_parallel,
+                )
+                for layer_idx in range(config.num_hidden_layers)
+            ]
+        )
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        *,
+        cu_seqlens: torch.Tensor | None = None,
+        max_seqlen: torch.Tensor | None = None,
+        seqlens: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        hidden_states = inputs_embeds
+        for encoder_layer in self.layers:
+            hidden_states = encoder_layer(
+                hidden_states,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=max_seqlen,
+                seqlens=seqlens,
+            )
+        return hidden_states
+
+
 class Siglip2VisionTransformer(nn.Module):
     def __init__(
         self,
@@ -1169,7 +1497,7 @@ class Siglip2VisionTransformer(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
-        attn_backend_override: _Backend | None = None,
+        attn_backend_override: AttentionBackendEnum | None = None,
     ):
         super().__init__()
         self.config = config
@@ -1187,6 +1515,19 @@ class Siglip2VisionTransformer(nn.Module):
         )
         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
 
+    def compute_attn_mask_seqlen(
+        self, cu_seqlens: torch.Tensor
+    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
+        max_seqlen, seqlens = None, None
+        if self.encoder.attn_backend in {
+            AttentionBackendEnum.FLASH_ATTN,
+            AttentionBackendEnum.ROCM_AITER_FA,
+        }:
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+        elif self.encoder.attn_backend == AttentionBackendEnum.XFORMERS:
+            seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
+        return max_seqlen, seqlens
+
     def forward(
         self,
         packed_seq_patches: tuple[torch.Tensor, torch.Tensor],
@@ -1203,15 +1544,20 @@ class Siglip2VisionTransformer(nn.Module):
         # Get embeddings from packed sequence
         hidden_states = self.embeddings((seq_patches, seq_sizes, token_grids))
 
-        grid_thws = torch.tensor(
-            [[1, token_grids[0][0].item(), token_grids[0][1].item()]]
-        )
-        last_hidden_state = self.encoder(hidden_states, grid_thws)
-        hidden_states = self.post_layernorm(last_hidden_state)
-
         # Add a pseudo batch dimension for the encoder
         hidden_states = hidden_states.unsqueeze(0)
 
+        cu_seqlens, _ = create_cumulative_seq_lengths(seq_sizes, hidden_states.device)
+        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+
+        hidden_states = self.encoder(
+            inputs_embeds=hidden_states,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+            seqlens=seqlens,
+        )
+        hidden_states = self.post_layernorm(hidden_states)
+
         if self.pixel_shuffle_scale_factor > 1:
             hidden_states = pixel_shuffle_varlen(
                 x=hidden_states,
@@ -1252,6 +1598,44 @@ class Siglip2VisionTransformer(nn.Module):
         return loaded_params
 
 
+class IsaacVisionEmbedding(nn.Module):
+    def __init__(
+        self,
+        vision_cfg: PixelShuffleSiglip2VisionConfig,
+        hidden_dim: int,
+        output_dim: int,
+        prefix: str,
+    ):
+        super().__init__()
+        self.transformer = Siglip2VisionTransformer(
+            vision_cfg, prefix=maybe_prefix(prefix, "vision_embedding")
+        )
+        self.linear_fc1 = ColumnParallelLinear(
+            hidden_dim,
+            4 * hidden_dim,
+            bias=False,
+            prefix=maybe_prefix(prefix, "vision_embedding.1"),
+            return_bias=False,
+        )
+        self.act = nn.SiLU()
+        self.linear_fc2 = RowParallelLinear(
+            4 * hidden_dim,
+            output_dim,
+            bias=False,
+            prefix=maybe_prefix(prefix, "vision_embedding.3"),
+            return_bias=False,
+        )
+
+    def forward(
+        self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(packed_seq_patches)
+        hidden_states = self.linear_fc1(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.linear_fc2(hidden_states)
+        return hidden_states
+
+
 @MULTIMODAL_REGISTRY.register_processor(
     IsaacMultiModalProcessor,
     info=IsaacProcessingInfo,
@@ -1278,6 +1662,10 @@ class IsaacForConditionalGeneration(
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
             "model.vision_embedding.": "vision_embedding.",
+            "vision_embedding.0": "vision_embedding.transformer",
+            "vision_embedding.1": "vision_embedding.linear_fc1",
+            "vision_embedding.2": "vision_embedding.act",
+            "vision_embedding.3": "vision_embedding.linear_fc2",
         }
     )
 
@@ -1325,17 +1713,11 @@ class IsaacForConditionalGeneration(
             raise ValueError("IsaacConfig should always have vision_config")
 
         hidden_dim = vision_cfg.hidden_size * (vision_cfg.pixel_shuffle_scale_factor**2)
-        self.vision_embedding = nn.Sequential(
-            Siglip2VisionTransformer(
-                vision_cfg, prefix=maybe_prefix(prefix, "vision_embedding")
-            ),
-            nn.Linear(
-                hidden_dim,
-                4 * hidden_dim,
-                bias=False,
-            ),
-            nn.SiLU(),
-            nn.Linear(4 * hidden_dim, config.hidden_size, bias=False),
+        self.vision_embedding = IsaacVisionEmbedding(
+            vision_cfg=vision_cfg,
+            hidden_dim=hidden_dim,
+            output_dim=config.hidden_size,
+            prefix=prefix,
         )
 
     def get_mrope_input_positions(
@@ -1502,6 +1884,6 @@ class IsaacForConditionalGeneration(
         """
         return MultiModelKeys.from_string_field(
             language_model="language_model",
-            connector="vision_embedding.3",  # The final linear layer
+            connector="vision_embedding.linear_fc2",  # The final linear layer
             tower_model="vision_embedding",
         )