Remove upstream fa checks (#29471)

Signed-off-by: mingyuanm <mingyuanm@nvidia.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Mingyuan Ma 2025-11-28 05:52:42 -08:00 committed by GitHub
parent e2f56c309d
commit 460d8bbf2d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 18 additions and 148 deletions

View File

@ -56,53 +56,28 @@ FP8_DTYPE = current_platform.fp8_dtype()
logger = init_logger(__name__) logger = init_logger(__name__)
def check_upstream_fa_availability(dtype: torch.dtype):
if (
dtype in (torch.float16, torch.bfloat16)
and current_platform.is_cuda()
and current_platform.has_device_capability(80)
):
from transformers.utils import is_flash_attn_2_available
return is_flash_attn_2_available()
if current_platform.is_rocm():
from importlib.util import find_spec
return find_spec("flash_attn") is not None
return False
def maybe_get_vit_flash_attn_backend( def maybe_get_vit_flash_attn_backend(
attn_backend: AttentionBackendEnum, attn_backend: AttentionBackendEnum,
use_upstream_fa: bool,
attn_backend_override: AttentionBackendEnum | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> tuple[AttentionBackendEnum, Callable | None]: ) -> tuple[AttentionBackendEnum, Callable | None]:
if current_platform.is_rocm(): if current_platform.is_rocm():
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
attn_backend = AttentionBackendEnum.ROCM_AITER_FA attn_backend = AttentionBackendEnum.ROCM_AITER_FA
elif ( elif (
check_upstream_fa_availability(torch.get_default_dtype()) attn_backend_override is None
and on_gfx9() and on_gfx9()
and attn_backend_override is None and attn_backend == AttentionBackendEnum.FLASH_ATTN
): ):
attn_backend = AttentionBackendEnum.FLASH_ATTN pass
use_upstream_fa = True
else: else:
return AttentionBackendEnum.TORCH_SDPA, None return AttentionBackendEnum.TORCH_SDPA, None
elif current_platform.is_cuda(): elif current_platform.is_cuda():
if ( pass
attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
):
attn_backend = AttentionBackendEnum.FLASH_ATTN
use_upstream_fa = True
elif current_platform.is_xpu(): elif current_platform.is_xpu():
assert attn_backend == AttentionBackendEnum.FLASH_ATTN, ( assert attn_backend == AttentionBackendEnum.FLASH_ATTN, (
"XPU platform only supports FLASH_ATTN as vision attention backend." "XPU platform only supports FLASH_ATTN as vision attention backend."
) )
use_upstream_fa = False pass
else: else:
return AttentionBackendEnum.TORCH_SDPA, None return AttentionBackendEnum.TORCH_SDPA, None
@ -113,10 +88,7 @@ def maybe_get_vit_flash_attn_backend(
if attn_backend == AttentionBackendEnum.ROCM_AITER_FA: if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func from aiter import flash_attn_varlen_func
else: else:
if use_upstream_fa: from vllm.attention.utils.fa_utils import flash_attn_varlen_func
from flash_attn import flash_attn_varlen_func
else:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
else: else:
flash_attn_varlen_func = None flash_attn_varlen_func = None
@ -501,11 +473,6 @@ class MultiHeadAttention(nn.Module):
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
# Some auto-selected backends can be upgraded
# to upstream flash attention if available.
# If vllm native fa is selected, we use it directly.
use_upstream_fa = False
self.attn_backend = ( self.attn_backend = (
backend backend
if backend if backend
@ -521,7 +488,6 @@ class MultiHeadAttention(nn.Module):
self.attn_backend, self._flash_attn_varlen_func = ( self.attn_backend, self._flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend( maybe_get_vit_flash_attn_backend(
self.attn_backend, self.attn_backend,
use_upstream_fa,
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
) )
@ -531,17 +497,8 @@ class MultiHeadAttention(nn.Module):
AttentionBackendEnum.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
} }
# this condition is just to make sure that the
# use_upstream_fa in the log is correct
if (
current_platform.is_rocm()
and self.attn_backend == AttentionBackendEnum.FLASH_ATTN
):
use_upstream_fa = True
logger.info_once( logger.info_once(
f"MultiHeadAttention attn_backend: {self.attn_backend}, " f"Using {self.attn_backend} for MultiHeadAttention in multimodal encoder."
f"use_upstream_fa: {use_upstream_fa}"
) )
def forward( def forward(

View File

@ -27,15 +27,11 @@ def flash_attn_maxseqlen_wrapper(
max_seqlen: torch.Tensor, max_seqlen: torch.Tensor,
batch_size: int, batch_size: int,
is_rocm_aiter: bool, is_rocm_aiter: bool,
use_upstream_fa: bool,
) -> torch.Tensor: ) -> torch.Tensor:
if is_rocm_aiter: if is_rocm_aiter:
from aiter import flash_attn_varlen_func from aiter import flash_attn_varlen_func
else: else:
if use_upstream_fa: from vllm.attention.utils.fa_utils import flash_attn_varlen_func
from flash_attn import flash_attn_varlen_func
else:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_varlen_func( output = flash_attn_varlen_func(
q, q,
@ -62,7 +58,6 @@ def flash_attn_maxseqlen_wrapper_fake(
max_seqlen: torch.Tensor, max_seqlen: torch.Tensor,
batch_size: int, batch_size: int,
is_rocm_aiter: bool, is_rocm_aiter: bool,
use_upstream_fa: bool,
) -> torch.Tensor: ) -> torch.Tensor:
b, s, h, d = q.shape b, s, h, d = q.shape
return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device) return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
@ -83,10 +78,9 @@ def vit_flash_attn_wrapper(
max_seqlen: torch.Tensor, max_seqlen: torch.Tensor,
batch_size: int, batch_size: int,
is_rocm_aiter: bool, is_rocm_aiter: bool,
use_upstream_fa: bool,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.ops.vllm.flash_attn_maxseqlen_wrapper( return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter
) )

View File

@ -18,6 +18,14 @@ elif current_platform.is_xpu():
reshape_and_cache_flash = ops.reshape_and_cache_flash reshape_and_cache_flash = ops.reshape_and_cache_flash
flash_attn_varlen_func = ops.flash_attn_varlen_func flash_attn_varlen_func = ops.flash_attn_varlen_func
get_scheduler_metadata = ops.get_scheduler_metadata get_scheduler_metadata = ops.get_scheduler_metadata
elif current_platform.is_rocm():
try:
from flash_attn import flash_attn_varlen_func # noqa: F401
except ImportError as e:
raise ImportError(
"Rocm platform requires upstream flash-attn "
"to be installed. Please install flash-attn first."
) from e
def get_flash_attn_version(requires_alibi: bool = False) -> int | None: def get_flash_attn_version(requires_alibi: bool = False) -> int | None:

View File

@ -11,7 +11,6 @@ from transformers.models.qwen2_vl import Qwen2VLProcessor
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import ( from vllm.attention.layer import (
check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend, maybe_get_vit_flash_attn_backend,
) )
from vllm.config import VllmConfig from vllm.config import VllmConfig
@ -294,12 +293,10 @@ class DotsVisionAttention(nn.Module):
torch.get_default_dtype(), torch.get_default_dtype(),
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
self.use_upstream_fa = False
self.attn_backend, self.flash_attn_varlen_func = ( self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend( maybe_get_vit_flash_attn_backend(
self.attn_backend, self.attn_backend,
self.use_upstream_fa,
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
) )
@ -569,11 +566,6 @@ class DotsVisionTransformer(nn.Module):
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
if (
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
):
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
self.out_hidden_size = config.hidden_size self.out_hidden_size = config.hidden_size
# Keep blocks for compatibility with other vision towers # Keep blocks for compatibility with other vision towers
num_layers = ( num_layers = (

View File

@ -38,7 +38,6 @@ from transformers import BatchFeature
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import ( from vllm.attention.layer import (
check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend, maybe_get_vit_flash_attn_backend,
) )
from vllm.config import VllmConfig from vllm.config import VllmConfig
@ -201,12 +200,9 @@ class Ernie4_5_VisionAttention(nn.Module):
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
self.use_upstream_fa = False
self.attn_backend, self.flash_attn_varlen_func = ( self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend( maybe_get_vit_flash_attn_backend(
self.attn_backend, self.attn_backend,
self.use_upstream_fa,
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
) )
@ -498,11 +494,6 @@ class Ernie4_5_VisionTransformer(nn.Module):
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
if (
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
):
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
@property @property
def dtype(self) -> torch.dtype: def dtype(self) -> torch.dtype:

View File

@ -47,10 +47,7 @@ from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor
from transformers.video_utils import VideoMetadata from transformers.video_utils import VideoMetadata
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import ( from vllm.attention.layer import maybe_get_vit_flash_attn_backend
check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend,
)
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
@ -296,12 +293,10 @@ class Glm4vVisionAttention(nn.Module):
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
self.use_upstream_fa = False
self.attn_backend, self.flash_attn_varlen_func = ( self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend( maybe_get_vit_flash_attn_backend(
self.attn_backend, self.attn_backend,
self.use_upstream_fa,
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
) )
@ -730,11 +725,6 @@ class Glm4vVisionTransformer(nn.Module):
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
if (
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
):
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
@property @property
def dtype(self) -> torch.dtype: def dtype(self) -> torch.dtype:

View File

@ -418,7 +418,6 @@ class KeyeSiglipAttention(nn.Module):
self.attn_backend, self.flash_attn_varlen_func = ( self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend( maybe_get_vit_flash_attn_backend(
self.attn_backend, self.attn_backend,
use_upstream_fa=False,
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
) )

View File

@ -33,7 +33,6 @@ from transformers.utils import torch_int
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import ( from vllm.attention.layer import (
check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend, maybe_get_vit_flash_attn_backend,
) )
from vllm.attention.ops.vit_attn_wrappers import ( from vllm.attention.ops.vit_attn_wrappers import (
@ -582,7 +581,6 @@ class SiglipAttention(nn.Module):
prefix: str = "", prefix: str = "",
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: AttentionBackendEnum | None = None, attn_backend_override: AttentionBackendEnum | None = None,
use_upstream_fa: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
@ -612,11 +610,9 @@ class SiglipAttention(nn.Module):
) )
self.attn_backend = attn_backend self.attn_backend = attn_backend
self.use_upstream_fa = use_upstream_fa
self.attn_backend, self.flash_attn_varlen_func = ( self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend( maybe_get_vit_flash_attn_backend(
self.attn_backend, self.attn_backend,
self.use_upstream_fa,
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
) )
@ -680,7 +676,6 @@ class SiglipAttention(nn.Module):
max_seqlen, max_seqlen,
batch_size, batch_size,
self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA, self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
self.use_upstream_fa,
) )
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
outputs = [] outputs = []
@ -783,7 +778,6 @@ class SiglipEncoderLayer(nn.Module):
*, *,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
attn_backend_override: AttentionBackendEnum | None = None, attn_backend_override: AttentionBackendEnum | None = None,
use_upstream_fa: bool = False,
): ):
super().__init__() super().__init__()
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
@ -796,7 +790,6 @@ class SiglipEncoderLayer(nn.Module):
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
attn_backend=attn_backend, attn_backend=attn_backend,
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
use_upstream_fa=use_upstream_fa,
) )
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP( self.mlp = SiglipMLP(
@ -852,13 +845,6 @@ class SiglipEncoder(nn.Module):
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
self.use_upstream_fa = False
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
} and check_upstream_fa_availability(torch.get_default_dtype()):
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
self.use_upstream_fa = True
if self.attn_backend not in { if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA, AttentionBackendEnum.TORCH_SDPA,
@ -875,7 +861,6 @@ class SiglipEncoder(nn.Module):
prefix=f"{prefix}.layers.{layer_idx}", prefix=f"{prefix}.layers.{layer_idx}",
attn_backend=self.attn_backend, attn_backend=self.attn_backend,
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
use_upstream_fa=self.use_upstream_fa,
) )
for layer_idx in range(config.num_hidden_layers) for layer_idx in range(config.num_hidden_layers)
] ]

View File

@ -307,7 +307,6 @@ class Qwen2_5_VisionAttention(nn.Module):
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
use_upstream_fa: bool = False,
attn_backend_override: AttentionBackendEnum | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
@ -344,24 +343,13 @@ class Qwen2_5_VisionAttention(nn.Module):
disable_tp=use_data_parallel, disable_tp=use_data_parallel,
) )
self.attn_backend = attn_backend self.attn_backend = attn_backend
self.use_upstream_fa = use_upstream_fa
self.attn_backend, self.flash_attn_varlen_func = ( self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend( maybe_get_vit_flash_attn_backend(
self.attn_backend, self.attn_backend,
self.use_upstream_fa,
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
) )
# On ROCm with FLASH_ATTN backend, upstream flash_attn is used
from vllm.platforms import current_platform
if (
current_platform.is_rocm()
and self.attn_backend == AttentionBackendEnum.FLASH_ATTN
):
self.use_upstream_fa = True
if current_platform.is_xpu():
self.use_upstream_fa = False
self.is_flash_attn_backend = self.attn_backend in { self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
@ -415,7 +403,6 @@ class Qwen2_5_VisionAttention(nn.Module):
max_seqlen, max_seqlen,
batch_size, batch_size,
self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA, self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
self.use_upstream_fa,
) )
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM. # Execute attention entry by entry for speed & less VRAM.
@ -459,7 +446,6 @@ class Qwen2_5_VisionBlock(nn.Module):
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
use_upstream_fa: bool = False,
attn_backend_override: AttentionBackendEnum | None = None, attn_backend_override: AttentionBackendEnum | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
@ -475,7 +461,6 @@ class Qwen2_5_VisionBlock(nn.Module):
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel, use_data_parallel=use_data_parallel,
attn_backend=attn_backend, attn_backend=attn_backend,
use_upstream_fa=use_upstream_fa,
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
self.mlp = Qwen2_5_VisionMLP( self.mlp = Qwen2_5_VisionMLP(
@ -644,7 +629,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
is_neox_style=True, is_neox_style=True,
) )
use_upstream_fa = False
self.attn_backend = get_vit_attn_backend( self.attn_backend = get_vit_attn_backend(
head_size=head_dim, head_size=head_dim,
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
@ -654,7 +638,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
self.attn_backend, self.flash_attn_varlen_func = ( self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend( maybe_get_vit_flash_attn_backend(
self.attn_backend, self.attn_backend,
use_upstream_fa,
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
) )
@ -681,7 +664,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
prefix=f"{prefix}.blocks.{layer_idx}", prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel, use_data_parallel=use_data_parallel,
attn_backend=self.attn_backend, attn_backend=self.attn_backend,
use_upstream_fa=use_upstream_fa,
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
for layer_idx in range(depth) for layer_idx in range(depth)

View File

@ -45,7 +45,6 @@ from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoP
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import ( from vllm.attention.layer import (
check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend, maybe_get_vit_flash_attn_backend,
) )
from vllm.config import VllmConfig from vllm.config import VllmConfig
@ -335,12 +334,10 @@ class Qwen2VisionAttention(nn.Module):
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
self.use_upstream_fa = False
self.attn_backend, self.flash_attn_varlen_func = ( self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend( maybe_get_vit_flash_attn_backend(
self.attn_backend, self.attn_backend,
self.use_upstream_fa,
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
) )
@ -657,11 +654,6 @@ class Qwen2VisionTransformer(nn.Module):
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
if (
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
):
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
@property @property
def dtype(self) -> torch.dtype: def dtype(self) -> torch.dtype:

View File

@ -47,7 +47,6 @@ from transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe import (
from transformers.models.whisper import WhisperFeatureExtractor from transformers.models.whisper import WhisperFeatureExtractor
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import check_upstream_fa_availability
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
@ -381,11 +380,6 @@ class Qwen3Omni_VisionTransformer(nn.Module):
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
if (
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and check_upstream_fa_availability(torch.get_default_dtype())
):
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
@property @property
def dtype(self) -> torch.dtype: def dtype(self) -> torch.dtype:

View File

@ -49,7 +49,6 @@ from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
from transformers.video_utils import VideoMetadata from transformers.video_utils import VideoMetadata
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import check_upstream_fa_availability
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
@ -202,7 +201,6 @@ class Qwen3_VisionBlock(nn.Module):
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA, attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
use_upstream_fa: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
if norm_layer is None: if norm_layer is None:
@ -217,7 +215,6 @@ class Qwen3_VisionBlock(nn.Module):
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel, use_data_parallel=use_data_parallel,
attn_backend=attn_backend, attn_backend=attn_backend,
use_upstream_fa=use_upstream_fa,
) )
self.mlp = Qwen3_VisionMLP( self.mlp = Qwen3_VisionMLP(
dim, dim,
@ -378,14 +375,6 @@ class Qwen3_VisionTransformer(nn.Module):
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
use_upstream_fa = False
if (
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
and self.attn_backend != AttentionBackendEnum.ROCM_AITER_FA
and check_upstream_fa_availability(torch.get_default_dtype())
):
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
use_upstream_fa = True
if self.attn_backend not in { if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
@ -407,7 +396,6 @@ class Qwen3_VisionTransformer(nn.Module):
prefix=f"{prefix}.blocks.{layer_idx}", prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel, use_data_parallel=use_data_parallel,
attn_backend=self.attn_backend, attn_backend=self.attn_backend,
use_upstream_fa=use_upstream_fa,
) )
for layer_idx in range(vision_config.depth) for layer_idx in range(vision_config.depth)
] ]

View File

@ -255,12 +255,10 @@ class Siglip2Attention(nn.Module):
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
self.use_upstream_fa = False
self.attn_backend, self.flash_attn_varlen_func = ( self.attn_backend, self.flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend( maybe_get_vit_flash_attn_backend(
self.attn_backend, self.attn_backend,
self.use_upstream_fa,
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
) )