Remove upstream fa checks (#29471)

Signed-off-by: mingyuanm <mingyuanm@nvidia.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Mingyuan Ma 2025-11-28 05:52:42 -08:00 committed by GitHub
parent e2f56c309d
commit 460d8bbf2d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 18 additions and 148 deletions

View File

@ -56,53 +56,28 @@ FP8_DTYPE = current_platform.fp8_dtype()
logger = init_logger(__name__)
def check_upstream_fa_availability(dtype: torch.dtype):
if (
dtype in (torch.float16, torch.bfloat16)
and current_platform.is_cuda()
and current_platform.has_device_capability(80)
):
from transformers.utils import is_flash_attn_2_available
return is_flash_attn_2_available()
if current_platform.is_rocm():
from importlib.util import find_spec
return find_spec("flash_attn") is not None
return False
def maybe_get_vit_flash_attn_backend(
attn_backend: AttentionBackendEnum,
use_upstream_fa: bool,
attn_backend_override: AttentionBackendEnum | None = None,
) -> tuple[AttentionBackendEnum, Callable | None]:
if current_platform.is_rocm():
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
attn_backend = AttentionBackendEnum.ROCM_AITER_FA
elif (
check_upstream_fa_availability(torch.get_default_dtype())
attn_backend_override is None
and on_gfx9()
and attn_backend_override is None
and attn_backend == AttentionBackendEnum.FLASH_ATTN
):
attn_backend = AttentionBackendEnum.FLASH_ATTN
use_upstream_fa = True
pass
else:
return AttentionBackendEnum.TORCH_SDPA, None
elif current_platform.is_cuda():
if (
attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
):
attn_backend = AttentionBackendEnum.FLASH_ATTN
use_upstream_fa = True
pass
elif current_platform.is_xpu():
assert attn_backend == AttentionBackendEnum.FLASH_ATTN, (
"XPU platform only supports FLASH_ATTN as vision attention backend."
)
use_upstream_fa = False
pass
else:
return AttentionBackendEnum.TORCH_SDPA, None
@ -113,10 +88,7 @@ def maybe_get_vit_flash_attn_backend(
if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
if use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
else:
flash_attn_varlen_func = None
@ -501,11 +473,6 @@ class MultiHeadAttention(nn.Module):
attn_backend_override=attn_backend_override,
)
# Some auto-selected backends can be upgraded
# to upstream flash attention if available.
# If vllm native fa is selected, we use it directly.
use_upstream_fa = False
self.attn_backend = (
backend
if backend
@ -521,7 +488,6 @@ class MultiHeadAttention(nn.Module):
self.attn_backend, self._flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
use_upstream_fa,
attn_backend_override=attn_backend_override,
)
)
@ -531,17 +497,8 @@ class MultiHeadAttention(nn.Module):
AttentionBackendEnum.ROCM_AITER_FA,
}
# this condition is just to make sure that the
# use_upstream_fa in the log is correct
if (
current_platform.is_rocm()
and self.attn_backend == AttentionBackendEnum.FLASH_ATTN
):
use_upstream_fa = True
logger.info_once(
f"MultiHeadAttention attn_backend: {self.attn_backend}, "
f"use_upstream_fa: {use_upstream_fa}"
f"Using {self.attn_backend} for MultiHeadAttention in multimodal encoder."
)
def forward(

View File

@ -27,15 +27,11 @@ def flash_attn_maxseqlen_wrapper(
max_seqlen: torch.Tensor,
batch_size: int,
is_rocm_aiter: bool,
use_upstream_fa: bool,
) -> torch.Tensor:
if is_rocm_aiter:
from aiter import flash_attn_varlen_func
else:
if use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_varlen_func(
q,
@ -62,7 +58,6 @@ def flash_attn_maxseqlen_wrapper_fake(
max_seqlen: torch.Tensor,
batch_size: int,
is_rocm_aiter: bool,
use_upstream_fa: bool,
) -> torch.Tensor:
b, s, h, d = q.shape
return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
@ -83,10 +78,9 @@ def vit_flash_attn_wrapper(
max_seqlen: torch.Tensor,
batch_size: int,
is_rocm_aiter: bool,
use_upstream_fa: bool,
) -> torch.Tensor:
return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa
q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter
)

View File

@ -18,6 +18,14 @@ elif current_platform.is_xpu():
reshape_and_cache_flash = ops.reshape_and_cache_flash
flash_attn_varlen_func = ops.flash_attn_varlen_func
get_scheduler_metadata = ops.get_scheduler_metadata
elif current_platform.is_rocm():
try:
from flash_attn import flash_attn_varlen_func # noqa: F401
except ImportError as e:
raise ImportError(
"Rocm platform requires upstream flash-attn "
"to be installed. Please install flash-attn first."
) from e
def get_flash_attn_version(requires_alibi: bool = False) -> int | None:

View File

@ -11,7 +11,6 @@ from transformers.models.qwen2_vl import Qwen2VLProcessor
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend,
)
from vllm.config import VllmConfig
@ -294,12 +293,10 @@ class DotsVisionAttention(nn.Module):
torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
self.use_upstream_fa = False
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
attn_backend_override=attn_backend_override,
)
)
@ -569,11 +566,6 @@ class DotsVisionTransformer(nn.Module):
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
if (
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
):
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
self.out_hidden_size = config.hidden_size
# Keep blocks for compatibility with other vision towers
num_layers = (

View File

@ -38,7 +38,6 @@ from transformers import BatchFeature
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend,
)
from vllm.config import VllmConfig
@ -201,12 +200,9 @@ class Ernie4_5_VisionAttention(nn.Module):
attn_backend_override=attn_backend_override,
)
self.use_upstream_fa = False
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
attn_backend_override=attn_backend_override,
)
)
@ -498,11 +494,6 @@ class Ernie4_5_VisionTransformer(nn.Module):
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
if (
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
):
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
@property
def dtype(self) -> torch.dtype:

View File

@ -47,10 +47,7 @@ from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor
from transformers.video_utils import VideoMetadata
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend,
)
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
@ -296,12 +293,10 @@ class Glm4vVisionAttention(nn.Module):
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
self.use_upstream_fa = False
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
attn_backend_override=attn_backend_override,
)
)
@ -730,11 +725,6 @@ class Glm4vVisionTransformer(nn.Module):
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
if (
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
):
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
@property
def dtype(self) -> torch.dtype:

View File

@ -418,7 +418,6 @@ class KeyeSiglipAttention(nn.Module):
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
use_upstream_fa=False,
attn_backend_override=attn_backend_override,
)
)

View File

@ -33,7 +33,6 @@ from transformers.utils import torch_int
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend,
)
from vllm.attention.ops.vit_attn_wrappers import (
@ -582,7 +581,6 @@ class SiglipAttention(nn.Module):
prefix: str = "",
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: AttentionBackendEnum | None = None,
use_upstream_fa: bool = False,
) -> None:
super().__init__()
@ -612,11 +610,9 @@ class SiglipAttention(nn.Module):
)
self.attn_backend = attn_backend
self.use_upstream_fa = use_upstream_fa
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
attn_backend_override=attn_backend_override,
)
)
@ -680,7 +676,6 @@ class SiglipAttention(nn.Module):
max_seqlen,
batch_size,
self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
self.use_upstream_fa,
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
outputs = []
@ -783,7 +778,6 @@ class SiglipEncoderLayer(nn.Module):
*,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: AttentionBackendEnum | None = None,
use_upstream_fa: bool = False,
):
super().__init__()
self.embed_dim = config.hidden_size
@ -796,7 +790,6 @@ class SiglipEncoderLayer(nn.Module):
prefix=f"{prefix}.self_attn",
attn_backend=attn_backend,
attn_backend_override=attn_backend_override,
use_upstream_fa=use_upstream_fa,
)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(
@ -852,13 +845,6 @@ class SiglipEncoder(nn.Module):
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
self.use_upstream_fa = False
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
} and check_upstream_fa_availability(torch.get_default_dtype()):
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
self.use_upstream_fa = True
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
@ -875,7 +861,6 @@ class SiglipEncoder(nn.Module):
prefix=f"{prefix}.layers.{layer_idx}",
attn_backend=self.attn_backend,
attn_backend_override=attn_backend_override,
use_upstream_fa=self.use_upstream_fa,
)
for layer_idx in range(config.num_hidden_layers)
]

View File

@ -307,7 +307,6 @@ class Qwen2_5_VisionAttention(nn.Module):
prefix: str = "",
use_data_parallel: bool = False,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
use_upstream_fa: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
@ -344,24 +343,13 @@ class Qwen2_5_VisionAttention(nn.Module):
disable_tp=use_data_parallel,
)
self.attn_backend = attn_backend
self.use_upstream_fa = use_upstream_fa
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
attn_backend_override=attn_backend_override,
)
)
# On ROCm with FLASH_ATTN backend, upstream flash_attn is used
from vllm.platforms import current_platform
if (
current_platform.is_rocm()
and self.attn_backend == AttentionBackendEnum.FLASH_ATTN
):
self.use_upstream_fa = True
if current_platform.is_xpu():
self.use_upstream_fa = False
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
@ -415,7 +403,6 @@ class Qwen2_5_VisionAttention(nn.Module):
max_seqlen,
batch_size,
self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
self.use_upstream_fa,
)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
@ -459,7 +446,6 @@ class Qwen2_5_VisionBlock(nn.Module):
prefix: str = "",
use_data_parallel: bool = False,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
use_upstream_fa: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
) -> None:
super().__init__()
@ -475,7 +461,6 @@ class Qwen2_5_VisionBlock(nn.Module):
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend=attn_backend,
use_upstream_fa=use_upstream_fa,
attn_backend_override=attn_backend_override,
)
self.mlp = Qwen2_5_VisionMLP(
@ -644,7 +629,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
is_neox_style=True,
)
use_upstream_fa = False
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
@ -654,7 +638,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
use_upstream_fa,
attn_backend_override=attn_backend_override,
)
)
@ -681,7 +664,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel,
attn_backend=self.attn_backend,
use_upstream_fa=use_upstream_fa,
attn_backend_override=attn_backend_override,
)
for layer_idx in range(depth)

View File

@ -45,7 +45,6 @@ from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoP
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import (
check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend,
)
from vllm.config import VllmConfig
@ -335,12 +334,10 @@ class Qwen2VisionAttention(nn.Module):
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
self.use_upstream_fa = False
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
attn_backend_override=attn_backend_override,
)
)
@ -657,11 +654,6 @@ class Qwen2VisionTransformer(nn.Module):
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
if (
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
):
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
@property
def dtype(self) -> torch.dtype:

View File

@ -47,7 +47,6 @@ from transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe import (
from transformers.models.whisper import WhisperFeatureExtractor
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import check_upstream_fa_availability
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
@ -381,11 +380,6 @@ class Qwen3Omni_VisionTransformer(nn.Module):
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
if (
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
):
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
@property
def dtype(self) -> torch.dtype:

View File

@ -49,7 +49,6 @@ from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
from transformers.video_utils import VideoMetadata
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import check_upstream_fa_availability
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
@ -202,7 +201,6 @@ class Qwen3_VisionBlock(nn.Module):
prefix: str = "",
use_data_parallel: bool = False,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
use_upstream_fa: bool = False,
) -> None:
super().__init__()
if norm_layer is None:
@ -217,7 +215,6 @@ class Qwen3_VisionBlock(nn.Module):
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend=attn_backend,
use_upstream_fa=use_upstream_fa,
)
self.mlp = Qwen3_VisionMLP(
dim,
@ -378,14 +375,6 @@ class Qwen3_VisionTransformer(nn.Module):
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
use_upstream_fa = False
if (
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and self.attn_backend != AttentionBackendEnum.ROCM_AITER_FA
and check_upstream_fa_availability(torch.get_default_dtype())
):
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
use_upstream_fa = True
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
@ -407,7 +396,6 @@ class Qwen3_VisionTransformer(nn.Module):
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel,
attn_backend=self.attn_backend,
use_upstream_fa=use_upstream_fa,
)
for layer_idx in range(vision_config.depth)
]

View File

@ -255,12 +255,10 @@ class Siglip2Attention(nn.Module):
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
self.use_upstream_fa = False
self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
self.attn_backend,
self.use_upstream_fa,
attn_backend_override=attn_backend_override,
)
)