mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 18:54:55 +08:00
Remove upstream fa checks (#29471)
Signed-off-by: mingyuanm <mingyuanm@nvidia.com> Signed-off-by: Roger Wang <hey@rogerw.io> Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
parent
e2f56c309d
commit
460d8bbf2d
@ -56,53 +56,28 @@ FP8_DTYPE = current_platform.fp8_dtype()
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def check_upstream_fa_availability(dtype: torch.dtype):
|
||||
if (
|
||||
dtype in (torch.float16, torch.bfloat16)
|
||||
and current_platform.is_cuda()
|
||||
and current_platform.has_device_capability(80)
|
||||
):
|
||||
from transformers.utils import is_flash_attn_2_available
|
||||
|
||||
return is_flash_attn_2_available()
|
||||
if current_platform.is_rocm():
|
||||
from importlib.util import find_spec
|
||||
|
||||
return find_spec("flash_attn") is not None
|
||||
return False
|
||||
|
||||
|
||||
def maybe_get_vit_flash_attn_backend(
|
||||
attn_backend: AttentionBackendEnum,
|
||||
use_upstream_fa: bool,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
) -> tuple[AttentionBackendEnum, Callable | None]:
|
||||
if current_platform.is_rocm():
|
||||
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
|
||||
attn_backend = AttentionBackendEnum.ROCM_AITER_FA
|
||||
|
||||
elif (
|
||||
check_upstream_fa_availability(torch.get_default_dtype())
|
||||
attn_backend_override is None
|
||||
and on_gfx9()
|
||||
and attn_backend_override is None
|
||||
and attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
):
|
||||
attn_backend = AttentionBackendEnum.FLASH_ATTN
|
||||
use_upstream_fa = True
|
||||
pass
|
||||
else:
|
||||
return AttentionBackendEnum.TORCH_SDPA, None
|
||||
|
||||
elif current_platform.is_cuda():
|
||||
if (
|
||||
attn_backend != AttentionBackendEnum.FLASH_ATTN
|
||||
and check_upstream_fa_availability(torch.get_default_dtype())
|
||||
):
|
||||
attn_backend = AttentionBackendEnum.FLASH_ATTN
|
||||
use_upstream_fa = True
|
||||
pass
|
||||
elif current_platform.is_xpu():
|
||||
assert attn_backend == AttentionBackendEnum.FLASH_ATTN, (
|
||||
"XPU platform only supports FLASH_ATTN as vision attention backend."
|
||||
)
|
||||
use_upstream_fa = False
|
||||
pass
|
||||
else:
|
||||
return AttentionBackendEnum.TORCH_SDPA, None
|
||||
|
||||
@ -113,10 +88,7 @@ def maybe_get_vit_flash_attn_backend(
|
||||
if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
|
||||
from aiter import flash_attn_varlen_func
|
||||
else:
|
||||
if use_upstream_fa:
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
else:
|
||||
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
|
||||
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
|
||||
else:
|
||||
flash_attn_varlen_func = None
|
||||
|
||||
@ -501,11 +473,6 @@ class MultiHeadAttention(nn.Module):
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
|
||||
# Some auto-selected backends can be upgraded
|
||||
# to upstream flash attention if available.
|
||||
# If vllm native fa is selected, we use it directly.
|
||||
use_upstream_fa = False
|
||||
|
||||
self.attn_backend = (
|
||||
backend
|
||||
if backend
|
||||
@ -521,7 +488,6 @@ class MultiHeadAttention(nn.Module):
|
||||
self.attn_backend, self._flash_attn_varlen_func = (
|
||||
maybe_get_vit_flash_attn_backend(
|
||||
self.attn_backend,
|
||||
use_upstream_fa,
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
)
|
||||
@ -531,17 +497,8 @@ class MultiHeadAttention(nn.Module):
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}
|
||||
|
||||
# this condition is just to make sure that the
|
||||
# use_upstream_fa in the log is correct
|
||||
if (
|
||||
current_platform.is_rocm()
|
||||
and self.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
):
|
||||
use_upstream_fa = True
|
||||
|
||||
logger.info_once(
|
||||
f"MultiHeadAttention attn_backend: {self.attn_backend}, "
|
||||
f"use_upstream_fa: {use_upstream_fa}"
|
||||
f"Using {self.attn_backend} for MultiHeadAttention in multimodal encoder."
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
@ -27,15 +27,11 @@ def flash_attn_maxseqlen_wrapper(
|
||||
max_seqlen: torch.Tensor,
|
||||
batch_size: int,
|
||||
is_rocm_aiter: bool,
|
||||
use_upstream_fa: bool,
|
||||
) -> torch.Tensor:
|
||||
if is_rocm_aiter:
|
||||
from aiter import flash_attn_varlen_func
|
||||
else:
|
||||
if use_upstream_fa:
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
else:
|
||||
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
|
||||
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
|
||||
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
||||
output = flash_attn_varlen_func(
|
||||
q,
|
||||
@ -62,7 +58,6 @@ def flash_attn_maxseqlen_wrapper_fake(
|
||||
max_seqlen: torch.Tensor,
|
||||
batch_size: int,
|
||||
is_rocm_aiter: bool,
|
||||
use_upstream_fa: bool,
|
||||
) -> torch.Tensor:
|
||||
b, s, h, d = q.shape
|
||||
return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
|
||||
@ -83,10 +78,9 @@ def vit_flash_attn_wrapper(
|
||||
max_seqlen: torch.Tensor,
|
||||
batch_size: int,
|
||||
is_rocm_aiter: bool,
|
||||
use_upstream_fa: bool,
|
||||
) -> torch.Tensor:
|
||||
return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
|
||||
q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa
|
||||
q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -18,6 +18,14 @@ elif current_platform.is_xpu():
|
||||
reshape_and_cache_flash = ops.reshape_and_cache_flash
|
||||
flash_attn_varlen_func = ops.flash_attn_varlen_func
|
||||
get_scheduler_metadata = ops.get_scheduler_metadata
|
||||
elif current_platform.is_rocm():
|
||||
try:
|
||||
from flash_attn import flash_attn_varlen_func # noqa: F401
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Rocm platform requires upstream flash-attn "
|
||||
"to be installed. Please install flash-attn first."
|
||||
) from e
|
||||
|
||||
|
||||
def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
|
||||
|
||||
@ -11,7 +11,6 @@ from transformers.models.qwen2_vl import Qwen2VLProcessor
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.layer import (
|
||||
check_upstream_fa_availability,
|
||||
maybe_get_vit_flash_attn_backend,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
@ -294,12 +293,10 @@ class DotsVisionAttention(nn.Module):
|
||||
torch.get_default_dtype(),
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
self.use_upstream_fa = False
|
||||
|
||||
self.attn_backend, self.flash_attn_varlen_func = (
|
||||
maybe_get_vit_flash_attn_backend(
|
||||
self.attn_backend,
|
||||
self.use_upstream_fa,
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
)
|
||||
@ -569,11 +566,6 @@ class DotsVisionTransformer(nn.Module):
|
||||
dtype=torch.get_default_dtype(),
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
if (
|
||||
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
|
||||
and check_upstream_fa_availability(torch.get_default_dtype())
|
||||
):
|
||||
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
|
||||
self.out_hidden_size = config.hidden_size
|
||||
# Keep blocks for compatibility with other vision towers
|
||||
num_layers = (
|
||||
|
||||
@ -38,7 +38,6 @@ from transformers import BatchFeature
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.layer import (
|
||||
check_upstream_fa_availability,
|
||||
maybe_get_vit_flash_attn_backend,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
@ -201,12 +200,9 @@ class Ernie4_5_VisionAttention(nn.Module):
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
|
||||
self.use_upstream_fa = False
|
||||
|
||||
self.attn_backend, self.flash_attn_varlen_func = (
|
||||
maybe_get_vit_flash_attn_backend(
|
||||
self.attn_backend,
|
||||
self.use_upstream_fa,
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
)
|
||||
@ -498,11 +494,6 @@ class Ernie4_5_VisionTransformer(nn.Module):
|
||||
dtype=torch.get_default_dtype(),
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
if (
|
||||
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
|
||||
and check_upstream_fa_availability(torch.get_default_dtype())
|
||||
):
|
||||
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
|
||||
@ -47,10 +47,7 @@ from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor
|
||||
from transformers.video_utils import VideoMetadata
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.layer import (
|
||||
check_upstream_fa_availability,
|
||||
maybe_get_vit_flash_attn_backend,
|
||||
)
|
||||
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
|
||||
@ -296,12 +293,10 @@ class Glm4vVisionAttention(nn.Module):
|
||||
dtype=torch.get_default_dtype(),
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
self.use_upstream_fa = False
|
||||
|
||||
self.attn_backend, self.flash_attn_varlen_func = (
|
||||
maybe_get_vit_flash_attn_backend(
|
||||
self.attn_backend,
|
||||
self.use_upstream_fa,
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
)
|
||||
@ -730,11 +725,6 @@ class Glm4vVisionTransformer(nn.Module):
|
||||
dtype=torch.get_default_dtype(),
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
if (
|
||||
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
|
||||
and check_upstream_fa_availability(torch.get_default_dtype())
|
||||
):
|
||||
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
|
||||
@ -418,7 +418,6 @@ class KeyeSiglipAttention(nn.Module):
|
||||
self.attn_backend, self.flash_attn_varlen_func = (
|
||||
maybe_get_vit_flash_attn_backend(
|
||||
self.attn_backend,
|
||||
use_upstream_fa=False,
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
)
|
||||
|
||||
@ -33,7 +33,6 @@ from transformers.utils import torch_int
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.layer import (
|
||||
check_upstream_fa_availability,
|
||||
maybe_get_vit_flash_attn_backend,
|
||||
)
|
||||
from vllm.attention.ops.vit_attn_wrappers import (
|
||||
@ -582,7 +581,6 @@ class SiglipAttention(nn.Module):
|
||||
prefix: str = "",
|
||||
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
use_upstream_fa: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@ -612,11 +610,9 @@ class SiglipAttention(nn.Module):
|
||||
)
|
||||
|
||||
self.attn_backend = attn_backend
|
||||
self.use_upstream_fa = use_upstream_fa
|
||||
self.attn_backend, self.flash_attn_varlen_func = (
|
||||
maybe_get_vit_flash_attn_backend(
|
||||
self.attn_backend,
|
||||
self.use_upstream_fa,
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
)
|
||||
@ -680,7 +676,6 @@ class SiglipAttention(nn.Module):
|
||||
max_seqlen,
|
||||
batch_size,
|
||||
self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
|
||||
self.use_upstream_fa,
|
||||
)
|
||||
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
|
||||
outputs = []
|
||||
@ -783,7 +778,6 @@ class SiglipEncoderLayer(nn.Module):
|
||||
*,
|
||||
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
use_upstream_fa: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
@ -796,7 +790,6 @@ class SiglipEncoderLayer(nn.Module):
|
||||
prefix=f"{prefix}.self_attn",
|
||||
attn_backend=attn_backend,
|
||||
attn_backend_override=attn_backend_override,
|
||||
use_upstream_fa=use_upstream_fa,
|
||||
)
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
self.mlp = SiglipMLP(
|
||||
@ -852,13 +845,6 @@ class SiglipEncoder(nn.Module):
|
||||
dtype=torch.get_default_dtype(),
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
self.use_upstream_fa = False
|
||||
if self.attn_backend not in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
} and check_upstream_fa_availability(torch.get_default_dtype()):
|
||||
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
|
||||
self.use_upstream_fa = True
|
||||
if self.attn_backend not in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TORCH_SDPA,
|
||||
@ -875,7 +861,6 @@ class SiglipEncoder(nn.Module):
|
||||
prefix=f"{prefix}.layers.{layer_idx}",
|
||||
attn_backend=self.attn_backend,
|
||||
attn_backend_override=attn_backend_override,
|
||||
use_upstream_fa=self.use_upstream_fa,
|
||||
)
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
]
|
||||
|
||||
@ -307,7 +307,6 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
|
||||
use_upstream_fa: bool = False,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@ -344,24 +343,13 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
disable_tp=use_data_parallel,
|
||||
)
|
||||
self.attn_backend = attn_backend
|
||||
self.use_upstream_fa = use_upstream_fa
|
||||
self.attn_backend, self.flash_attn_varlen_func = (
|
||||
maybe_get_vit_flash_attn_backend(
|
||||
self.attn_backend,
|
||||
self.use_upstream_fa,
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
)
|
||||
# On ROCm with FLASH_ATTN backend, upstream flash_attn is used
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if (
|
||||
current_platform.is_rocm()
|
||||
and self.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||
):
|
||||
self.use_upstream_fa = True
|
||||
if current_platform.is_xpu():
|
||||
self.use_upstream_fa = False
|
||||
self.is_flash_attn_backend = self.attn_backend in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
@ -415,7 +403,6 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
max_seqlen,
|
||||
batch_size,
|
||||
self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
|
||||
self.use_upstream_fa,
|
||||
)
|
||||
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
|
||||
# Execute attention entry by entry for speed & less VRAM.
|
||||
@ -459,7 +446,6 @@ class Qwen2_5_VisionBlock(nn.Module):
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
|
||||
use_upstream_fa: bool = False,
|
||||
attn_backend_override: AttentionBackendEnum | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@ -475,7 +461,6 @@ class Qwen2_5_VisionBlock(nn.Module):
|
||||
prefix=f"{prefix}.attn",
|
||||
use_data_parallel=use_data_parallel,
|
||||
attn_backend=attn_backend,
|
||||
use_upstream_fa=use_upstream_fa,
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
self.mlp = Qwen2_5_VisionMLP(
|
||||
@ -644,7 +629,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
is_neox_style=True,
|
||||
)
|
||||
|
||||
use_upstream_fa = False
|
||||
self.attn_backend = get_vit_attn_backend(
|
||||
head_size=head_dim,
|
||||
dtype=torch.get_default_dtype(),
|
||||
@ -654,7 +638,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
self.attn_backend, self.flash_attn_varlen_func = (
|
||||
maybe_get_vit_flash_attn_backend(
|
||||
self.attn_backend,
|
||||
use_upstream_fa,
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
)
|
||||
@ -681,7 +664,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
prefix=f"{prefix}.blocks.{layer_idx}",
|
||||
use_data_parallel=use_data_parallel,
|
||||
attn_backend=self.attn_backend,
|
||||
use_upstream_fa=use_upstream_fa,
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
for layer_idx in range(depth)
|
||||
|
||||
@ -45,7 +45,6 @@ from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoP
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.layer import (
|
||||
check_upstream_fa_availability,
|
||||
maybe_get_vit_flash_attn_backend,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
@ -335,12 +334,10 @@ class Qwen2VisionAttention(nn.Module):
|
||||
dtype=torch.get_default_dtype(),
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
self.use_upstream_fa = False
|
||||
|
||||
self.attn_backend, self.flash_attn_varlen_func = (
|
||||
maybe_get_vit_flash_attn_backend(
|
||||
self.attn_backend,
|
||||
self.use_upstream_fa,
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
)
|
||||
@ -657,11 +654,6 @@ class Qwen2VisionTransformer(nn.Module):
|
||||
dtype=torch.get_default_dtype(),
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
if (
|
||||
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
|
||||
and check_upstream_fa_availability(torch.get_default_dtype())
|
||||
):
|
||||
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
|
||||
@ -47,7 +47,6 @@ from transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe import (
|
||||
from transformers.models.whisper import WhisperFeatureExtractor
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.layer import check_upstream_fa_availability
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_pp_group
|
||||
@ -381,11 +380,6 @@ class Qwen3Omni_VisionTransformer(nn.Module):
|
||||
dtype=torch.get_default_dtype(),
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
if (
|
||||
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
|
||||
and check_upstream_fa_availability(torch.get_default_dtype())
|
||||
):
|
||||
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
|
||||
@ -49,7 +49,6 @@ from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
|
||||
from transformers.video_utils import VideoMetadata
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.layer import check_upstream_fa_availability
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
|
||||
@ -202,7 +201,6 @@ class Qwen3_VisionBlock(nn.Module):
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
|
||||
use_upstream_fa: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if norm_layer is None:
|
||||
@ -217,7 +215,6 @@ class Qwen3_VisionBlock(nn.Module):
|
||||
prefix=f"{prefix}.attn",
|
||||
use_data_parallel=use_data_parallel,
|
||||
attn_backend=attn_backend,
|
||||
use_upstream_fa=use_upstream_fa,
|
||||
)
|
||||
self.mlp = Qwen3_VisionMLP(
|
||||
dim,
|
||||
@ -378,14 +375,6 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
dtype=torch.get_default_dtype(),
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
use_upstream_fa = False
|
||||
if (
|
||||
self.attn_backend != AttentionBackendEnum.FLASH_ATTN
|
||||
and self.attn_backend != AttentionBackendEnum.ROCM_AITER_FA
|
||||
and check_upstream_fa_availability(torch.get_default_dtype())
|
||||
):
|
||||
self.attn_backend = AttentionBackendEnum.FLASH_ATTN
|
||||
use_upstream_fa = True
|
||||
|
||||
if self.attn_backend not in {
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
@ -407,7 +396,6 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
prefix=f"{prefix}.blocks.{layer_idx}",
|
||||
use_data_parallel=use_data_parallel,
|
||||
attn_backend=self.attn_backend,
|
||||
use_upstream_fa=use_upstream_fa,
|
||||
)
|
||||
for layer_idx in range(vision_config.depth)
|
||||
]
|
||||
|
||||
@ -255,12 +255,10 @@ class Siglip2Attention(nn.Module):
|
||||
dtype=torch.get_default_dtype(),
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
self.use_upstream_fa = False
|
||||
|
||||
self.attn_backend, self.flash_attn_varlen_func = (
|
||||
maybe_get_vit_flash_attn_backend(
|
||||
self.attn_backend,
|
||||
self.use_upstream_fa,
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user