From c7c3853e9e3b3e0341d251e5a63879ee849cf769 Mon Sep 17 00:00:00 2001
From: Yang
Date: Wed, 17 Dec 2025 18:30:44 -0800
Subject: [PATCH] fix cr

Signed-off-by: Yang
---
 vllm/model_executor/models/isaac.py | 192 +++++++---------------------
 1 file changed, 44 insertions(+), 148 deletions(-)

diff --git a/vllm/model_executor/models/isaac.py b/vllm/model_executor/models/isaac.py
index 85d9568b89048..097363f83c4dd 100644
--- a/vllm/model_executor/models/isaac.py
+++ b/vllm/model_executor/models/isaac.py
@@ -18,11 +18,8 @@
 from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig
 from transformers.tokenization_utils import TensorType
 from typing_extensions import TypedDict, Unpack
-from vllm.attention.backends.registry import AttentionBackendEnum
-from vllm.attention.layer import (
-    maybe_get_vit_flash_attn_backend,
-)
-from vllm.config import VllmConfig
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
+from vllm.config import MultiModalConfig, VllmConfig
 from vllm.config.model import ModelConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
@@ -51,7 +48,6 @@ from vllm.model_executor.models.utils import (
     init_vllm_registered_model,
     maybe_prefix,
 )
-from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (
     MultiModalDataDict,
@@ -97,11 +93,15 @@ class PixelShuffleSiglip2VisionConfig(Siglip2VisionConfig):
 
 def create_cumulative_seq_lengths(
     seq_sizes: torch.Tensor, device: torch.device
-) -> tuple[torch.Tensor, int]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     """Create cumulative sequence lengths for variable-length attention."""
     cu_seqlens = torch.zeros(len(seq_sizes) + 1, dtype=torch.int32, device=device)
     cu_seqlens[1:] = seq_sizes.cumsum(0)
-    max_seqlen = int(seq_sizes.max().item()) if len(seq_sizes) > 0 else 0
+    max_seqlen = (
+        seq_sizes.max()
+        if len(seq_sizes) > 0
+        else torch.tensor(0, dtype=torch.int32, device=device)
+    )
     return cu_seqlens, max_seqlen
@@ -763,9 +763,6 @@ class IsaacImageProcessor:
 class IsaacProcessor:
     """Processor wrapper (tokenizer + IsaacImageProcessor)."""
 
-    attributes = ["tokenizer"]
-    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
-
     def __init__(self, image_processor=None, tokenizer=None, **kwargs):
         self.image_token = kwargs.pop("image_token", "")
         self.image_processor = image_processor or IsaacImageProcessor(kwargs)
@@ -963,24 +960,6 @@ class IsaacMultiModalProcessor(BaseMultiModalProcessor):
         ]
 
 
-def all_gather_interleave(local_tensor: torch.Tensor, hidden_size: int, tp_size: int):
-    """All-gather the input tensor interleavely across model parallel group."""
-    import torch.distributed as dist
-
-    gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
-    dist.all_gather(
-        gathered_tensors, local_tensor, group=parallel_state.get_tp_group().device_group
-    )
-
-    gathered_tensors_split = [
-        torch.split(tensor, hidden_size // tp_size, -1) for tensor in gathered_tensors
-    ]
-    ordered_tensors = [
-        tensor for pair in zip(*gathered_tensors_split) for tensor in pair
-    ]
-    return torch.cat(ordered_tensors, dim=-1)
-
-
 class Siglip2VisionAttention(nn.Module):
     def __init__(
         self,
@@ -988,12 +967,15 @@
         quant_config: QuantizationConfig | None = None,
         *,
         prefix: str = "",
-        use_data_parallel: bool = False,
-        attn_backend: AttentionBackendEnum | None = None,
-        attn_backend_override: AttentionBackendEnum | None = None,
+        multimodal_config: MultiModalConfig | None = None,
     ) -> None:
         super().__init__()
+        use_data_parallel = (
+            multimodal_config.mm_encoder_tp_mode == "data"
+            if multimodal_config
+            else False
+        )
         self.tp_size = (
             1
             if use_data_parallel
@@ -1025,26 +1007,12 @@
             disable_tp=use_data_parallel,
         )
 
-        self.attn_backend = attn_backend
-
-        if self.attn_backend not in {
-            AttentionBackendEnum.FLASH_ATTN,
-            AttentionBackendEnum.TORCH_SDPA,
-            AttentionBackendEnum.ROCM_AITER_FA,
-        }:
-            raise RuntimeError(
-                f"Isaac vision embedding does not support {self.attn_backend} backend."
-            )
-        self.attn_backend, self.flash_attn_varlen_func = (
-            maybe_get_vit_flash_attn_backend(
-                self.attn_backend,
-                attn_backend_override=attn_backend_override,
-            )
+        self.attn = MMEncoderAttention(
+            num_heads=self.num_attention_heads_per_partition,
+            head_size=self.hidden_size_per_attention_head,
+            prefix=f"{prefix}.attn",
+            multimodal_config=multimodal_config,
         )
-        self.is_flash_attn_backend = self.attn_backend in {
-            AttentionBackendEnum.FLASH_ATTN,
-            AttentionBackendEnum.ROCM_AITER_FA,
-        }
 
     def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
         seq_len, bs, _ = qkv.shape
@@ -1064,7 +1032,6 @@
         *,
         cu_seqlens: torch.Tensor,
         max_seqlen: torch.Tensor | None,
-        seqlens: torch.Tensor | None,
     ) -> torch.Tensor:
         batch_size, _, _ = hidden_states.shape
         if batch_size != 1:
@@ -1074,45 +1041,14 @@
         q, k, v = self.split_qkv(x)
         q, k, v = (rearrange(t, "s b h d -> b s h d") for t in (q, k, v))
 
-        if self.is_flash_attn_backend:
-            q, k, v = (rearrange(t, "b s ... -> (b s) ...") for t in (q, k, v))
-            output = self.flash_attn_varlen_func(
-                q,
-                k,
-                v,
-                cu_seqlens_q=cu_seqlens,
-                cu_seqlens_k=cu_seqlens,
-                max_seqlen_q=max_seqlen,
-                max_seqlen_k=max_seqlen,
-                dropout_p=0.0,
-                causal=False,
-            )
-            context_layer = rearrange(
-                output, "(b s) h d -> s b (h d)", b=batch_size
-            ).contiguous()
-        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
-            outputs = []
-            for i in range(1, len(cu_seqlens)):
-                start_idx = cu_seqlens[i - 1]
-                end_idx = cu_seqlens[i]
-                q_i = q[:, start_idx:end_idx]
-                k_i = k[:, start_idx:end_idx]
-                v_i = v[:, start_idx:end_idx]
-                q_i, k_i, v_i = (
-                    rearrange(tensor, "b s h d -> b h s d")
-                    for tensor in (q_i, k_i, v_i)
-                )
-                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
-                output_i = rearrange(output_i, "b h s d -> b s h d")
-                outputs.append(output_i)
-            context_layer = torch.cat(outputs, dim=1)
-            context_layer = rearrange(
-                context_layer, "b s h d -> s b (h d)"
-            ).contiguous()
-        else:
-            raise RuntimeError(
-                f"Isaac vision embedding does not support {self.attn_backend} backend."
-            )
+        context_layer = self.attn(
+            query=q,
+            key=k,
+            value=v,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+        )
+        context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
 
         output, _ = self.out_proj(context_layer)
         output = rearrange(output, "s b d -> b s d")
@@ -1126,9 +1062,7 @@
         quant_config: QuantizationConfig | None = None,
         *,
         prefix: str = "",
-        attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
-        attn_backend_override: AttentionBackendEnum | None = None,
-        use_data_parallel: bool = False,
+        multimodal_config: MultiModalConfig | None = None,
     ) -> None:
         super().__init__()
         self.embed_dim = config.hidden_size
@@ -1137,9 +1071,7 @@
             config,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
-            use_data_parallel=use_data_parallel,
-            attn_backend=attn_backend,
-            attn_backend_override=attn_backend_override,
+            multimodal_config=multimodal_config,
         )
         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(
@@ -1154,7 +1086,6 @@
         *,
         cu_seqlens: torch.Tensor,
         max_seqlen: torch.Tensor | None,
-        seqlens: torch.Tensor | None,
     ) -> torch.Tensor:
         residual = hidden_states
@@ -1163,7 +1094,6 @@
             hidden_states=hidden_states,
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
-            seqlens=seqlens,
         )
         hidden_states = residual + hidden_states
@@ -1182,36 +1112,17 @@
         quant_config: QuantizationConfig | None = None,
         *,
         prefix: str = "",
-        use_data_parallel: bool = False,
-        attn_backend_override: AttentionBackendEnum | None = None,
+        multimodal_config: MultiModalConfig | None = None,
     ) -> None:
         super().__init__()
         self.config = config
-        embed_dim = config.hidden_size
-        num_heads = config.num_attention_heads
-        head_dim = embed_dim // num_heads
-        self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim,
-            dtype=torch.get_default_dtype(),
-            attn_backend_override=attn_backend_override,
-        )
-        if self.attn_backend not in {
-            AttentionBackendEnum.FLASH_ATTN,
-            AttentionBackendEnum.TORCH_SDPA,
-            AttentionBackendEnum.ROCM_AITER_FA,
-        }:
-            raise RuntimeError(
-                f"Isaac vision embedding does not support {self.attn_backend} backend."
-            )
         self.layers = nn.ModuleList(
             [
                 Siglip2EncoderLayer(
                     config,
                     quant_config=quant_config,
                     prefix=f"{prefix}.layers.{layer_idx}",
-                    attn_backend=self.attn_backend,
-                    attn_backend_override=attn_backend_override,
-                    use_data_parallel=use_data_parallel,
+                    multimodal_config=multimodal_config,
                 )
                 for layer_idx in range(config.num_hidden_layers)
             ]
@@ -1223,7 +1134,6 @@
         *,
         cu_seqlens: torch.Tensor | None = None,
         max_seqlen: torch.Tensor | None = None,
-        seqlens: torch.Tensor | None = None,
     ) -> torch.Tensor:
         hidden_states = inputs_embeds
         for encoder_layer in self.layers:
@@ -1231,7 +1141,6 @@
                 hidden_states,
                 cu_seqlens=cu_seqlens,
                 max_seqlen=max_seqlen,
-                seqlens=seqlens,
             )
         return hidden_states
@@ -1242,8 +1151,7 @@
         config: PixelShuffleSiglip2VisionConfig,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
-        attn_backend_override: AttentionBackendEnum | None = None,
+        multimodal_config: MultiModalConfig | None = None,
     ):
         super().__init__()
         self.config = config
@@ -1256,22 +1164,10 @@
             config,
             quant_config=quant_config,
             prefix=f"{prefix}.encoder",
-            use_data_parallel=use_data_parallel,
-            attn_backend_override=attn_backend_override,
+            multimodal_config=multimodal_config,
         )
         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
 
-    def compute_attn_mask_seqlen(
-        self, cu_seqlens: torch.Tensor
-    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
-        max_seqlen, seqlens = None, None
-        if self.encoder.attn_backend in {
-            AttentionBackendEnum.FLASH_ATTN,
-            AttentionBackendEnum.ROCM_AITER_FA,
-        }:
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
-        return max_seqlen, seqlens
-
     def forward(
         self,
         packed_seq_patches: tuple[torch.Tensor, torch.Tensor],
@@ -1291,14 +1187,14 @@
         # Add a pseudo batch dimension for the encoder
         hidden_states = hidden_states.unsqueeze(0)
 
-        cu_seqlens, _ = create_cumulative_seq_lengths(seq_sizes, hidden_states.device)
-        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+        cu_seqlens, max_seqlen = create_cumulative_seq_lengths(
+            seq_sizes, hidden_states.device
+        )
 
         hidden_states = self.encoder(
             inputs_embeds=hidden_states,
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
-            seqlens=seqlens,
         )
         hidden_states = self.post_layernorm(hidden_states)
@@ -1349,18 +1245,22 @@
         hidden_dim: int,
         output_dim: int,
         quant_config: QuantizationConfig | None = None,
+        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
     ):
         super().__init__()
         self.transformer = Siglip2VisionTransformer(
-            vision_cfg, prefix=maybe_prefix(prefix, "vision_embedding")
+            vision_cfg,
+            quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "0"),
+            multimodal_config=multimodal_config,
         )
         self.linear_fc1 = ColumnParallelLinear(
             hidden_dim,
             4 * hidden_dim,
             bias=False,
             quant_config=quant_config,
-            prefix=maybe_prefix(prefix, "vision_embedding.1"),
+            prefix=maybe_prefix(prefix, "1"),
             return_bias=False,
         )
         self.act = nn.SiLU()
@@ -1369,7 +1269,7 @@
             output_dim,
             bias=False,
             quant_config=quant_config,
-            prefix=maybe_prefix(prefix, "vision_embedding.3"),
+            prefix=maybe_prefix(prefix, "3"),
             return_bias=False,
         )
@@ -1457,11 +1357,6 @@
         vision_cfg = config.vision_config
         if vision_cfg is None:
             raise ValueError("IsaacConfig should always have vision_config")
-        vision_cfg.preserve_original_pe = True
-        vision_cfg.use_rope = False
-        vision_cfg.hidden_stride = vision_cfg.pixel_shuffle_scale_factor
-        vision_cfg.window_size = 32 * 2
-        vision_cfg.fullatt_block_indexes = None
         attn_impl = (
             config.vision_attn_implementation
             if config.vision_attn_implementation is not None
@@ -1476,6 +1371,7 @@ class IsaacForConditionalGeneration(
             hidden_dim=hidden_dim,
             output_dim=config.hidden_size,
             quant_config=quant_config,
+            multimodal_config=self.multimodal_config,
             prefix=maybe_prefix(prefix, "vision_embedding"),
         )