From 7cfd83ad92092c957f4c4745fda7f78f675153c9 Mon Sep 17 00:00:00 2001
From: Oscar Gonzalez
Date: Tue, 2 Dec 2025 01:09:19 -0500
Subject: [PATCH] 1. Remove upstream FA checks (#29471)
 2. Remove deprecated xFormers (#29262)
 3. Update _get_prompt_updates()

Signed-off-by: Oscar Gonzalez
---
 vllm/model_executor/models/isaac.py | 46 ++++-------------------------
 1 file changed, 5 insertions(+), 41 deletions(-)

diff --git a/vllm/model_executor/models/isaac.py b/vllm/model_executor/models/isaac.py
index 82dae62cb56e4..e5a2d5440724a 100644
--- a/vllm/model_executor/models/isaac.py
+++ b/vllm/model_executor/models/isaac.py
@@ -23,12 +23,8 @@ from typing_extensions import TypedDict, Unpack
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.layer import (
-    check_upstream_fa_availability,
     maybe_get_vit_flash_attn_backend,
 )
-from vllm.attention.ops.vit_attn_wrappers import (
-    vit_xformers_attn_wrapper,
-)
 from vllm.config import VllmConfig
 from vllm.config.model import ModelConfig
 from vllm.distributed import parallel_state
@@ -73,6 +69,7 @@ from vllm.multimodal.processing import (
     BaseProcessingInfo,
     PromptReplacement,
     PromptUpdate,
+    PromptUpdateDetails,
 )
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
@@ -1204,14 +1201,7 @@ class IsaacMultiModalProcessor(BaseMultiModalProcessor):
         hf_processor_mm_kwargs: Mapping[str, Any],
         out_mm_kwargs: MultiModalKwargs,
     ) -> Sequence[PromptUpdate]:
-        # hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
-        hf_config = self.info.get_hf_config()
         image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
-        tokenizer = self.info.get_tokenizer()
-        placeholder_id = tokenizer.encode(
-            hf_config.vision_token,
-            add_special_tokens=False,
-        )
 
         pixel_shuffle_scale = getattr(image_processor, "pixel_shuffle_scale", 2)
         merge_length = pixel_shuffle_scale**2
@@ -1221,13 +1211,14 @@ class IsaacMultiModalProcessor(BaseMultiModalProcessor):
             grid_thw = out_item["image_grid_thw"].data
             assert isinstance(grid_thw, torch.Tensor)
 
-            num_tokens = int(grid_thw.prod()) // merge_length
-            return placeholder_id * num_tokens
+            feature_size = int(grid_thw.prod()) // merge_length
+            repl_full = "<|image_pad|>" * feature_size
+            return PromptUpdateDetails.select_text(repl_full, "<|image_pad|>")
 
         return [
             PromptReplacement(
                 modality="image",
-                target=placeholder_id,
+                target="<image>",
                 replacement=get_replacement_isaac,
             )
         ]
@@ -1259,7 +1250,6 @@ class Siglip2VisionAttention(nn.Module):
         *,
         prefix: str = "",
         use_data_parallel: bool = False,
-        use_upstream_fa: bool = False,
         attn_backend: AttentionBackendEnum | None = None,
         attn_backend_override: AttentionBackendEnum | None = None,
     ) -> None:
@@ -1296,19 +1286,11 @@ class Siglip2VisionAttention(nn.Module):
             disable_tp=use_data_parallel,
         )
 
-        self.use_upstream_fa = use_upstream_fa
         self.attn_backend = attn_backend
-        if self.attn_backend not in {
-            AttentionBackendEnum.FLASH_ATTN,
-            AttentionBackendEnum.ROCM_AITER_FA,
-        } and check_upstream_fa_availability(torch.get_default_dtype()):
-            self.attn_backend = AttentionBackendEnum.FLASH_ATTN
-            self.use_upstream_fa = True
 
         if self.attn_backend not in {
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.TORCH_SDPA,
-            AttentionBackendEnum.XFORMERS,
             AttentionBackendEnum.ROCM_AITER_FA,
         }:
             raise RuntimeError(
@@ -1317,7 +1299,6 @@ class Siglip2VisionAttention(nn.Module):
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             )
         )
@@ -1389,10 +1370,6 @@ class Siglip2VisionAttention(nn.Module):
             context_layer = rearrange(
                 context_layer, "b s h d -> s b (h d)"
             ).contiguous()
-        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
-            if seqlens is None:
-                raise ValueError("xFormers attention backend requires seqlens tensor.")
-            context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
         else:
             raise RuntimeError(
                 f"Isaac vision embedding does not support {self.attn_backend} backend."
@@ -1412,7 +1389,6 @@ class Siglip2EncoderLayer(nn.Module):
         prefix: str = "",
         attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
         attn_backend_override: AttentionBackendEnum | None = None,
-        use_upstream_fa: bool = False,
         use_data_parallel: bool = False,
     ) -> None:
         super().__init__()
@@ -1423,7 +1399,6 @@ class Siglip2EncoderLayer(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
             use_data_parallel=use_data_parallel,
-            use_upstream_fa=use_upstream_fa,
             attn_backend=attn_backend,
             attn_backend_override=attn_backend_override,
         )
@@ -1481,17 +1456,9 @@ class Siglip2Encoder(nn.Module):
             dtype=torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        self.use_upstream_fa = False
-        if self.attn_backend not in {
-            AttentionBackendEnum.FLASH_ATTN,
-            AttentionBackendEnum.ROCM_AITER_FA,
-        } and check_upstream_fa_availability(torch.get_default_dtype()):
-            self.attn_backend = AttentionBackendEnum.FLASH_ATTN
-            self.use_upstream_fa = True
         if self.attn_backend not in {
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.TORCH_SDPA,
-            AttentionBackendEnum.XFORMERS,
             AttentionBackendEnum.ROCM_AITER_FA,
         }:
             raise RuntimeError(
@@ -1505,7 +1472,6 @@ class Siglip2Encoder(nn.Module):
                 prefix=f"{prefix}.layers.{layer_idx}",
                 attn_backend=self.attn_backend,
                 attn_backend_override=attn_backend_override,
-                use_upstream_fa=self.use_upstream_fa,
                 use_data_parallel=use_data_parallel,
             )
             for layer_idx in range(config.num_hidden_layers)
@@ -1565,8 +1531,6 @@ class Siglip2VisionTransformer(nn.Module):
             AttentionBackendEnum.ROCM_AITER_FA,
         }:
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
-        elif self.encoder.attn_backend == AttentionBackendEnum.XFORMERS:
-            seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
         return max_seqlen, seqlens
 
     def forward(
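
Note on the _get_prompt_updates() change: the replacement length is now
derived from the image grid rather than from re-encoding the vision token.
A minimal standalone sketch of that arithmetic, assuming a (t, h, w) grid
tensor and the default pixel_shuffle_scale of 2 (the helper name below is
illustrative, not from the patch):

    import torch

    def placeholder_count(grid_thw: torch.Tensor, pixel_shuffle_scale: int = 2) -> int:
        # t * h * w vision patches are merged pixel_shuffle_scale**2 at a
        # time, yielding one "<|image_pad|>" placeholder per merged group.
        merge_length = pixel_shuffle_scale**2
        return int(grid_thw.prod()) // merge_length

    # A 1 x 32 x 32 patch grid with scale 2 gives 1024 // 4 = 256
    # placeholders, so repl_full repeats "<|image_pad|>" 256 times.
    assert placeholder_count(torch.tensor([1, 32, 32])) == 256

PromptUpdateDetails.select_text(repl_full, "<|image_pad|>") then marks which
tokens of the expanded replacement carry image embeddings, so the engine can
align vision features with the pad tokens.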