1. Remove upstream FA checks (#29471)
2. Remove deprecated xformers (#29262)
3. Update _get_prompt_updates()

Signed-off-by: Oscar Gonzalez <ogonzal6@alumni.jh.edu>
Oscar Gonzalez 2025-12-02 01:09:19 -05:00 committed by Yang Liu
parent c10f5653ba
commit 7cfd83ad92


@@ -23,12 +23,8 @@ from typing_extensions import TypedDict, Unpack
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend,
)
from vllm.attention.ops.vit_attn_wrappers import (
vit_xformers_attn_wrapper,
)
from vllm.config import VllmConfig
from vllm.config.model import ModelConfig
from vllm.distributed import parallel_state
@@ -73,6 +69,7 @@ from vllm.multimodal.processing import (
BaseProcessingInfo,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
@@ -1204,14 +1201,7 @@ class IsaacMultiModalProcessor(BaseMultiModalProcessor):
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
# hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
hf_config = self.info.get_hf_config()
image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
placeholder_id = tokenizer.encode(
hf_config.vision_token,
add_special_tokens=False,
)
pixel_shuffle_scale = getattr(image_processor, "pixel_shuffle_scale", 2)
merge_length = pixel_shuffle_scale**2
@@ -1221,13 +1211,14 @@ class IsaacMultiModalProcessor(BaseMultiModalProcessor):
grid_thw = out_item["image_grid_thw"].data
assert isinstance(grid_thw, torch.Tensor)
num_tokens = int(grid_thw.prod()) // merge_length
return placeholder_id * num_tokens
feature_size = int(grid_thw.prod()) // merge_length
repl_full = "<|image_pad|>" * feature_size
return PromptUpdateDetails.select_text(repl_full, "<|image_pad|>")
return [
PromptReplacement(
modality="image",
target=placeholder_id,
target="<image>",
replacement=get_replacement_isaac,
)
]
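A minimal sketch of the replacement arithmetic above, using stand-in values rather than a real checkpoint (the 32x32 grid and pixel_shuffle_scale=2 are illustrative): each image contributes grid_thw.prod() patch tokens, divided by pixel_shuffle_scale**2, and the result is rendered as that many "<|image_pad|>" tokens, which PromptUpdateDetails.select_text then marks as the image-feature portion of the replacement.

import torch

# Illustrative values: one temporal frame, a 32x32 patch grid, pixel_shuffle_scale=2.
grid_thw = torch.tensor([1, 32, 32])
pixel_shuffle_scale = 2
merge_length = pixel_shuffle_scale**2

# Mirrors get_replacement_isaac: placeholder tokens emitted per image.
feature_size = int(grid_thw.prod()) // merge_length  # 1 * 32 * 32 // 4 == 256
repl_full = "<|image_pad|>" * feature_size

print(feature_size)                            # 256
print(len(repl_full) // len("<|image_pad|>"))  # 256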
@@ -1259,7 +1250,6 @@ class Siglip2VisionAttention(nn.Module):
*,
prefix: str = "",
use_data_parallel: bool = False,
use_upstream_fa: bool = False,
attn_backend: AttentionBackendEnum | None = None,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
@@ -1296,19 +1286,11 @@ class Siglip2VisionAttention(nn.Module):
disable_tp=use_data_parallel,
)
self.use_upstream_fa = use_upstream_fa
self.attn_backend = attn_backend
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
} and check_upstream_fa_availability(torch.get_default_dtype()):
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
self.use_upstream_fa = True
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -1317,7 +1299,6 @@ class Siglip2VisionAttention(nn.Module):
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
attn_backend_override=attn_backend_override,
)
)
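A self-contained sketch of the check that remains once the upstream-FA upgrade path is gone; the Backend enum is a stand-in for AttentionBackendEnum, and the supported set shown here assumes the xFormers member is dropped along with its code path in this commit.

from enum import Enum, auto

class Backend(Enum):  # stand-in for AttentionBackendEnum (assumption)
    FLASH_ATTN = auto()
    TORCH_SDPA = auto()
    XFORMERS = auto()
    ROCM_AITER_FA = auto()

# Assumed post-commit supported set: no silent upgrade to upstream flash-attn,
# and no xFormers fallback.
SUPPORTED = {Backend.FLASH_ATTN, Backend.TORCH_SDPA, Backend.ROCM_AITER_FA}

def resolve_backend(requested: Backend) -> Backend:
    if requested not in SUPPORTED:
        raise RuntimeError(
            f"Isaac vision embedding does not support {requested} backend."
        )
    return requested

resolve_backend(Backend.TORCH_SDPA)  # ok
# resolve_backend(Backend.XFORMERS)  # would raise RuntimeError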
@@ -1389,10 +1370,6 @@ class Siglip2VisionAttention(nn.Module):
context_layer = rearrange(
context_layer, "b s h d -> s b (h d)"
).contiguous()
elif self.attn_backend == AttentionBackendEnum.XFORMERS:
if seqlens is None:
raise ValueError("xFormers attention backend requires seqlens tensor.")
context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
else:
raise RuntimeError(
f"Isaac vision embedding does not support {self.attn_backend} backend."
@@ -1412,7 +1389,6 @@ class Siglip2EncoderLayer(nn.Module):
prefix: str = "",
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: AttentionBackendEnum | None = None,
use_upstream_fa: bool = False,
use_data_parallel: bool = False,
) -> None:
super().__init__()
@@ -1423,7 +1399,6 @@ class Siglip2EncoderLayer(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
use_data_parallel=use_data_parallel,
use_upstream_fa=use_upstream_fa,
attn_backend=attn_backend,
attn_backend_override=attn_backend_override,
)
@@ -1481,17 +1456,9 @@ class Siglip2Encoder(nn.Module):
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
self.use_upstream_fa = False
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
} and check_upstream_fa_availability(torch.get_default_dtype()):
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
self.use_upstream_fa = True
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.XFORMERS,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
@@ -1505,7 +1472,6 @@ class Siglip2Encoder(nn.Module):
prefix=f"{prefix}.layers.{layer_idx}",
attn_backend=self.attn_backend,
attn_backend_override=attn_backend_override,
use_upstream_fa=self.use_upstream_fa,
use_data_parallel=use_data_parallel,
)
for layer_idx in range(config.num_hidden_layers)
@@ -1565,8 +1531,6 @@ class Siglip2VisionTransformer(nn.Module):
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.encoder.attn_backend == AttentionBackendEnum.XFORMERS:
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
return max_seqlen, seqlens
def forward(
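A small worked example of the metadata computed in the hunk above (values are illustrative): per-image lengths are the consecutive differences of cu_seqlens, and max_seqlen is their maximum; with the xFormers branch removed, only the flash-attn-style backends fill in max_seqlen from cu_seqlens here.

import torch

# Illustrative cumulative sequence lengths for three packed images.
cu_seqlens = torch.tensor([0, 256, 576, 1600])

seqlens = cu_seqlens[1:] - cu_seqlens[:-1]  # tensor([ 256,  320, 1024])
max_seqlen = seqlens.max()                  # tensor(1024)
print(seqlens.tolist(), int(max_seqlen))    # [256, 320, 1024] 1024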