Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 00:45:01 +08:00)
Remove upstream fa checks (#29471)
Signed-off-by: mingyuanm <mingyuanm@nvidia.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>

Parent: e2f56c309d
Commit: 460d8bbf2d
@@ -56,53 +56,28 @@ FP8_DTYPE = current_platform.fp8_dtype()
 logger = init_logger(__name__)
 
 
-def check_upstream_fa_availability(dtype: torch.dtype):
-    if (
-        dtype in (torch.float16, torch.bfloat16)
-        and current_platform.is_cuda()
-        and current_platform.has_device_capability(80)
-    ):
-        from transformers.utils import is_flash_attn_2_available
-
-        return is_flash_attn_2_available()
-    if current_platform.is_rocm():
-        from importlib.util import find_spec
-
-        return find_spec("flash_attn") is not None
-    return False
-
-
 def maybe_get_vit_flash_attn_backend(
     attn_backend: AttentionBackendEnum,
-    use_upstream_fa: bool,
     attn_backend_override: AttentionBackendEnum | None = None,
 ) -> tuple[AttentionBackendEnum, Callable | None]:
     if current_platform.is_rocm():
         if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
             attn_backend = AttentionBackendEnum.ROCM_AITER_FA
 
         elif (
-            check_upstream_fa_availability(torch.get_default_dtype())
+            attn_backend_override is None
             and on_gfx9()
-            and attn_backend_override is None
+            and attn_backend == AttentionBackendEnum.FLASH_ATTN
         ):
-            attn_backend = AttentionBackendEnum.FLASH_ATTN
-            use_upstream_fa = True
+            pass
         else:
             return AttentionBackendEnum.TORCH_SDPA, None
 
     elif current_platform.is_cuda():
-        if (
-            attn_backend != AttentionBackendEnum.FLASH_ATTN
-            and check_upstream_fa_availability(torch.get_default_dtype())
-        ):
-            attn_backend = AttentionBackendEnum.FLASH_ATTN
-            use_upstream_fa = True
+        pass
     elif current_platform.is_xpu():
         assert attn_backend == AttentionBackendEnum.FLASH_ATTN, (
             "XPU platform only supports FLASH_ATTN as vision attention backend."
         )
-        use_upstream_fa = False
+        pass
     else:
         return AttentionBackendEnum.TORCH_SDPA, None
 
@@ -113,10 +88,7 @@ def maybe_get_vit_flash_attn_backend(
         if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
             from aiter import flash_attn_varlen_func
         else:
-            if use_upstream_fa:
-                from flash_attn import flash_attn_varlen_func
-            else:
-                from vllm.attention.utils.fa_utils import flash_attn_varlen_func
+            from vllm.attention.utils.fa_utils import flash_attn_varlen_func
     else:
         flash_attn_varlen_func = None
 
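Note: taken together, the two hunks above reduce the helper to a two-argument contract: pass the backend resolved for the vision tower plus an optional override, and receive the (possibly downgraded) backend together with the varlen kernel, which on non-AITER paths now always comes from vllm.attention.utils.fa_utils. A minimal caller sketch under that reading; the function name and the fallback handling are illustrative, not code from this commit:

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import maybe_get_vit_flash_attn_backend


def resolve_vit_attention(
    attn_backend: AttentionBackendEnum,
    attn_backend_override: AttentionBackendEnum | None = None,
):
    # Post-commit signature: there is no use_upstream_fa flag to thread through.
    backend, flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
        attn_backend,
        attn_backend_override=attn_backend_override,
    )
    if flash_attn_varlen_func is None:
        # No flash-attention kernel for this backend/platform combination;
        # the vision tower falls back to its SDPA path.
        ...
    return backend, flash_attn_varlen_func
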
@@ -501,11 +473,6 @@ class MultiHeadAttention(nn.Module):
             attn_backend_override=attn_backend_override,
         )
 
-        # Some auto-selected backends can be upgraded
-        # to upstream flash attention if available.
-        # If vllm native fa is selected, we use it directly.
-        use_upstream_fa = False
-
         self.attn_backend = (
             backend
             if backend
@@ -521,7 +488,6 @@ class MultiHeadAttention(nn.Module):
         self.attn_backend, self._flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             )
         )
@@ -531,17 +497,8 @@ class MultiHeadAttention(nn.Module):
             AttentionBackendEnum.ROCM_AITER_FA,
         }
 
-        # this condition is just to make sure that the
-        # use_upstream_fa in the log is correct
-        if (
-            current_platform.is_rocm()
-            and self.attn_backend == AttentionBackendEnum.FLASH_ATTN
-        ):
-            use_upstream_fa = True
-
         logger.info_once(
-            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
-            f"use_upstream_fa: {use_upstream_fa}"
+            f"Using {self.attn_backend} for MultiHeadAttention in multimodal encoder."
         )
 
     def forward(
@@ -27,15 +27,11 @@ def flash_attn_maxseqlen_wrapper(
     max_seqlen: torch.Tensor,
     batch_size: int,
     is_rocm_aiter: bool,
-    use_upstream_fa: bool,
 ) -> torch.Tensor:
     if is_rocm_aiter:
         from aiter import flash_attn_varlen_func
     else:
-        if use_upstream_fa:
-            from flash_attn import flash_attn_varlen_func
-        else:
-            from vllm.attention.utils.fa_utils import flash_attn_varlen_func
+        from vllm.attention.utils.fa_utils import flash_attn_varlen_func
     q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
     output = flash_attn_varlen_func(
         q,
@@ -62,7 +58,6 @@ def flash_attn_maxseqlen_wrapper_fake(
     max_seqlen: torch.Tensor,
     batch_size: int,
     is_rocm_aiter: bool,
-    use_upstream_fa: bool,
 ) -> torch.Tensor:
     b, s, h, d = q.shape
     return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
@@ -83,10 +78,9 @@ def vit_flash_attn_wrapper(
     max_seqlen: torch.Tensor,
     batch_size: int,
     is_rocm_aiter: bool,
-    use_upstream_fa: bool,
 ) -> torch.Tensor:
     return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
-        q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa
+        q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter
    )
 
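The custom op and its Python wrapper lose the flag as well, so the only remaining switch is is_rocm_aiter. A toy invocation consistent with the signatures above; the tensor sizes, dtype, and device="cuda" are illustrative and assume a flash-attn-capable CUDA build:

import torch

from vllm.attention.ops.vit_attn_wrappers import vit_flash_attn_wrapper

batch, seq, heads, head_dim = 2, 16, 8, 64
q = torch.randn(batch, seq, heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
# One cumulative boundary per sequence in the flattened (batch * seq) layout.
cu_seqlens = torch.arange(0, (batch + 1) * seq, seq, dtype=torch.int32, device="cuda")
max_seqlen = torch.tensor(seq, device="cuda")

# Post-commit argument list: use_upstream_fa is gone.
out = vit_flash_attn_wrapper(q, k, v, cu_seqlens, max_seqlen, batch, is_rocm_aiter=False)
# Per flash_attn_maxseqlen_wrapper_fake above, out has shape (seq, batch, heads * head_dim).
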
@@ -18,6 +18,14 @@ elif current_platform.is_xpu():
     reshape_and_cache_flash = ops.reshape_and_cache_flash
     flash_attn_varlen_func = ops.flash_attn_varlen_func
     get_scheduler_metadata = ops.get_scheduler_metadata
+elif current_platform.is_rocm():
+    try:
+        from flash_attn import flash_attn_varlen_func  # noqa: F401
+    except ImportError as e:
+        raise ImportError(
+            "Rocm platform requires upstream flash-attn "
+            "to be installed. Please install flash-attn first."
+        ) from e
 
 
 def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
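With this hunk, vllm.attention.utils.fa_utils re-exports the upstream flash-attn kernel on ROCm and raises an ImportError at import time if it is missing, so the call sites below can import flash_attn_varlen_func from this one module regardless of platform. A small hedged sanity check, not part of the commit itself:

# The changed call sites all import the varlen kernel from this one module,
# whether it resolves to vLLM's own op or to upstream flash-attn on ROCm.
from vllm.attention.utils.fa_utils import flash_attn_varlen_func

assert callable(flash_attn_varlen_func)
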
@@ -11,7 +11,6 @@ from transformers.models.qwen2_vl import Qwen2VLProcessor
 
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.layer import (
-    check_upstream_fa_availability,
     maybe_get_vit_flash_attn_backend,
 )
 from vllm.config import VllmConfig
@@ -294,12 +293,10 @@ class DotsVisionAttention(nn.Module):
             torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        self.use_upstream_fa = False
 
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             )
         )
@@ -569,11 +566,6 @@ class DotsVisionTransformer(nn.Module):
             dtype=torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        if (
-            self.attn_backend != AttentionBackendEnum.FLASH_ATTN
-            and check_upstream_fa_availability(torch.get_default_dtype())
-        ):
-            self.attn_backend = AttentionBackendEnum.FLASH_ATTN
         self.out_hidden_size = config.hidden_size
         # Keep blocks for compatibility with other vision towers
         num_layers = (
@@ -38,7 +38,6 @@ from transformers import BatchFeature
 
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.layer import (
-    check_upstream_fa_availability,
     maybe_get_vit_flash_attn_backend,
 )
 from vllm.config import VllmConfig
@@ -201,12 +200,9 @@ class Ernie4_5_VisionAttention(nn.Module):
             attn_backend_override=attn_backend_override,
         )
 
-        self.use_upstream_fa = False
-
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             )
         )
@@ -498,11 +494,6 @@ class Ernie4_5_VisionTransformer(nn.Module):
             dtype=torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        if (
-            self.attn_backend != AttentionBackendEnum.FLASH_ATTN
-            and check_upstream_fa_availability(torch.get_default_dtype())
-        ):
-            self.attn_backend = AttentionBackendEnum.FLASH_ATTN
 
     @property
     def dtype(self) -> torch.dtype:
@@ -47,10 +47,7 @@ from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor
 from transformers.video_utils import VideoMetadata
 
 from vllm.attention.backends.registry import AttentionBackendEnum
-from vllm.attention.layer import (
-    check_upstream_fa_availability,
-    maybe_get_vit_flash_attn_backend,
-)
+from vllm.attention.layer import maybe_get_vit_flash_attn_backend
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
 from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
@@ -296,12 +293,10 @@ class Glm4vVisionAttention(nn.Module):
             dtype=torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        self.use_upstream_fa = False
 
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             )
         )
@@ -730,11 +725,6 @@ class Glm4vVisionTransformer(nn.Module):
             dtype=torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        if (
-            self.attn_backend != AttentionBackendEnum.FLASH_ATTN
-            and check_upstream_fa_availability(torch.get_default_dtype())
-        ):
-            self.attn_backend = AttentionBackendEnum.FLASH_ATTN
 
     @property
     def dtype(self) -> torch.dtype:
@@ -418,7 +418,6 @@ class KeyeSiglipAttention(nn.Module):
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                use_upstream_fa=False,
                 attn_backend_override=attn_backend_override,
             )
         )
@@ -33,7 +33,6 @@ from transformers.utils import torch_int
 
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.layer import (
-    check_upstream_fa_availability,
     maybe_get_vit_flash_attn_backend,
 )
 from vllm.attention.ops.vit_attn_wrappers import (
@@ -582,7 +581,6 @@ class SiglipAttention(nn.Module):
         prefix: str = "",
         attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
         attn_backend_override: AttentionBackendEnum | None = None,
-        use_upstream_fa: bool = False,
     ) -> None:
         super().__init__()
 
@@ -612,11 +610,9 @@
         )
 
         self.attn_backend = attn_backend
-        self.use_upstream_fa = use_upstream_fa
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             )
         )
@@ -680,7 +676,6 @@
                 max_seqlen,
                 batch_size,
                 self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
-                self.use_upstream_fa,
             )
         elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
             outputs = []
@@ -783,7 +778,6 @@ class SiglipEncoderLayer(nn.Module):
         *,
         attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
         attn_backend_override: AttentionBackendEnum | None = None,
-        use_upstream_fa: bool = False,
     ):
         super().__init__()
         self.embed_dim = config.hidden_size
@@ -796,7 +790,6 @@
             prefix=f"{prefix}.self_attn",
             attn_backend=attn_backend,
             attn_backend_override=attn_backend_override,
-            use_upstream_fa=use_upstream_fa,
         )
         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(
@@ -852,13 +845,6 @@ class SiglipEncoder(nn.Module):
             dtype=torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        self.use_upstream_fa = False
-        if self.attn_backend not in {
-            AttentionBackendEnum.FLASH_ATTN,
-            AttentionBackendEnum.ROCM_AITER_FA,
-        } and check_upstream_fa_availability(torch.get_default_dtype()):
-            self.attn_backend = AttentionBackendEnum.FLASH_ATTN
-            self.use_upstream_fa = True
         if self.attn_backend not in {
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.TORCH_SDPA,
@@ -875,7 +861,6 @@
                     prefix=f"{prefix}.layers.{layer_idx}",
                     attn_backend=self.attn_backend,
                     attn_backend_override=attn_backend_override,
-                    use_upstream_fa=self.use_upstream_fa,
                 )
                 for layer_idx in range(config.num_hidden_layers)
             ]
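The model-side edits above and in the files that follow are the same mechanical change: use_upstream_fa disappears from constructor signatures, attribute state, and the arguments forwarded into the attention layers. A condensed sketch of the resulting wiring, with hypothetical class names standing in for the SiglipEncoder/SiglipEncoderLayer pair above:

import torch.nn as nn

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import maybe_get_vit_flash_attn_backend


class EncoderLayerSketch(nn.Module):
    # Stand-in for SiglipEncoderLayer: the use_upstream_fa keyword no longer exists.

    def __init__(
        self,
        attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()
        # The attention layer resolves its kernel internally from the backend alone.
        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                attn_backend,
                attn_backend_override=attn_backend_override,
            )
        )


class EncoderSketch(nn.Module):
    # Stand-in for SiglipEncoder: no upstream-FA upgrade pass before building layers.

    def __init__(
        self,
        num_layers: int,
        attn_backend: AttentionBackendEnum,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()
        self.layers = nn.ModuleList(
            [
                EncoderLayerSketch(
                    attn_backend=attn_backend,
                    attn_backend_override=attn_backend_override,
                )
                for _ in range(num_layers)
            ]
        )
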
@@ -307,7 +307,6 @@ class Qwen2_5_VisionAttention(nn.Module):
         prefix: str = "",
         use_data_parallel: bool = False,
         attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
-        use_upstream_fa: bool = False,
         attn_backend_override: AttentionBackendEnum | None = None,
     ) -> None:
         super().__init__()
@@ -344,24 +343,13 @@
             disable_tp=use_data_parallel,
         )
         self.attn_backend = attn_backend
-        self.use_upstream_fa = use_upstream_fa
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             )
         )
-        # On ROCm with FLASH_ATTN backend, upstream flash_attn is used
-        from vllm.platforms import current_platform
 
-        if (
-            current_platform.is_rocm()
-            and self.attn_backend == AttentionBackendEnum.FLASH_ATTN
-        ):
-            self.use_upstream_fa = True
-        if current_platform.is_xpu():
-            self.use_upstream_fa = False
         self.is_flash_attn_backend = self.attn_backend in {
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.ROCM_AITER_FA,
@@ -415,7 +403,6 @@
                 max_seqlen,
                 batch_size,
                 self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
-                self.use_upstream_fa,
             )
         elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
@@ -459,7 +446,6 @@ class Qwen2_5_VisionBlock(nn.Module):
         prefix: str = "",
         use_data_parallel: bool = False,
         attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
-        use_upstream_fa: bool = False,
         attn_backend_override: AttentionBackendEnum | None = None,
     ) -> None:
         super().__init__()
@@ -475,7 +461,6 @@
             prefix=f"{prefix}.attn",
             use_data_parallel=use_data_parallel,
             attn_backend=attn_backend,
-            use_upstream_fa=use_upstream_fa,
             attn_backend_override=attn_backend_override,
         )
         self.mlp = Qwen2_5_VisionMLP(
@@ -644,7 +629,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
             is_neox_style=True,
         )
 
-        use_upstream_fa = False
         self.attn_backend = get_vit_attn_backend(
             head_size=head_dim,
             dtype=torch.get_default_dtype(),
@@ -654,7 +638,6 @@
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             )
         )
@@ -681,7 +664,6 @@
                 prefix=f"{prefix}.blocks.{layer_idx}",
                 use_data_parallel=use_data_parallel,
                 attn_backend=self.attn_backend,
-                use_upstream_fa=use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             )
             for layer_idx in range(depth)
@@ -45,7 +45,6 @@ from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoP
 
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.layer import (
-    check_upstream_fa_availability,
     maybe_get_vit_flash_attn_backend,
 )
 from vllm.config import VllmConfig
@@ -335,12 +334,10 @@ class Qwen2VisionAttention(nn.Module):
             dtype=torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        self.use_upstream_fa = False
 
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             )
         )
@@ -657,11 +654,6 @@ class Qwen2VisionTransformer(nn.Module):
             dtype=torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        if (
-            self.attn_backend != AttentionBackendEnum.FLASH_ATTN
-            and check_upstream_fa_availability(torch.get_default_dtype())
-        ):
-            self.attn_backend = AttentionBackendEnum.FLASH_ATTN
 
     @property
     def dtype(self) -> torch.dtype:
@@ -47,7 +47,6 @@ from transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe import (
 from transformers.models.whisper import WhisperFeatureExtractor
 
 from vllm.attention.backends.registry import AttentionBackendEnum
-from vllm.attention.layer import check_upstream_fa_availability
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.distributed import get_pp_group
@@ -381,11 +380,6 @@ class Qwen3Omni_VisionTransformer(nn.Module):
             dtype=torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        if (
-            self.attn_backend != AttentionBackendEnum.FLASH_ATTN
-            and check_upstream_fa_availability(torch.get_default_dtype())
-        ):
-            self.attn_backend = AttentionBackendEnum.FLASH_ATTN
 
     @property
     def dtype(self) -> torch.dtype:
@@ -49,7 +49,6 @@ from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
 from transformers.video_utils import VideoMetadata
 
 from vllm.attention.backends.registry import AttentionBackendEnum
-from vllm.attention.layer import check_upstream_fa_availability
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
@@ -202,7 +201,6 @@ class Qwen3_VisionBlock(nn.Module):
         prefix: str = "",
         use_data_parallel: bool = False,
         attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
-        use_upstream_fa: bool = False,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -217,7 +215,6 @@
             prefix=f"{prefix}.attn",
             use_data_parallel=use_data_parallel,
             attn_backend=attn_backend,
-            use_upstream_fa=use_upstream_fa,
         )
         self.mlp = Qwen3_VisionMLP(
             dim,
@@ -378,14 +375,6 @@ class Qwen3_VisionTransformer(nn.Module):
             dtype=torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        use_upstream_fa = False
-        if (
-            self.attn_backend != AttentionBackendEnum.FLASH_ATTN
-            and self.attn_backend != AttentionBackendEnum.ROCM_AITER_FA
-            and check_upstream_fa_availability(torch.get_default_dtype())
-        ):
-            self.attn_backend = AttentionBackendEnum.FLASH_ATTN
-            use_upstream_fa = True
 
         if self.attn_backend not in {
             AttentionBackendEnum.FLASH_ATTN,
@@ -407,7 +396,6 @@
                 prefix=f"{prefix}.blocks.{layer_idx}",
                 use_data_parallel=use_data_parallel,
                 attn_backend=self.attn_backend,
-                use_upstream_fa=use_upstream_fa,
             )
             for layer_idx in range(vision_config.depth)
         ]
@@ -255,12 +255,10 @@ class Siglip2Attention(nn.Module):
             dtype=torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        self.use_upstream_fa = False
 
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
                 attn_backend_override=attn_backend_override,
             )
         )