mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 06:09:10 +08:00
[Multi Modal] Add FA3 in VIT (#24347)
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
This commit is contained in:
parent
fdb09c77d6
commit
72fc8aa412
@ -34,11 +34,11 @@ EXPECTED_MM_BEAM_SEARCH_RES = [
|
|||||||
],
|
],
|
||||||
[
|
[
|
||||||
"The image shows a Venn diagram with three over",
|
"The image shows a Venn diagram with three over",
|
||||||
"The image shows a Venn diagram with three intersect",
|
"This image shows a Venn diagram with three over",
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
"This image displays a gradient of colors ranging from",
|
"This image displays a gradient of colors ranging from",
|
||||||
"The image displays a gradient of colors ranging from",
|
"This image displays a gradient of colors forming a spectrum",
|
||||||
],
|
],
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -36,31 +36,52 @@ def test_mha_attn_platform(device: str):
|
|||||||
torch.set_default_dtype(torch.float16)
|
torch.set_default_dtype(torch.float16)
|
||||||
|
|
||||||
if device == "cpu":
|
if device == "cpu":
|
||||||
with patch("vllm.attention.selector.current_platform",
|
with patch("vllm.attention.layer.current_platform", CpuPlatform()), \
|
||||||
CpuPlatform()), \
|
patch("vllm.model_executor.models.vision.current_platform",
|
||||||
patch("vllm.platforms.current_platform", CpuPlatform()):
|
CpuPlatform()):
|
||||||
attn = MultiHeadAttention(16, 64, scale=1)
|
attn = MultiHeadAttention(16, 64, scale=1)
|
||||||
assert attn.attn_backend == _Backend.TORCH_SDPA_VLLM_V1
|
assert attn.attn_backend == _Backend.TORCH_SDPA
|
||||||
elif device == "hip":
|
elif device == "hip":
|
||||||
with patch("vllm.attention.selector.current_platform",
|
with patch("vllm.attention.layer.current_platform", RocmPlatform()), \
|
||||||
RocmPlatform()), \
|
patch("vllm.model_executor.models.vision.current_platform",
|
||||||
patch("vllm.platforms.current_platform", RocmPlatform()), \
|
RocmPlatform()):
|
||||||
patch("vllm.attention.layer.current_platform", RocmPlatform()):
|
|
||||||
attn = MultiHeadAttention(16, 64, scale=1)
|
attn = MultiHeadAttention(16, 64, scale=1)
|
||||||
assert attn.attn_backend == _Backend.TORCH_SDPA
|
assert attn.attn_backend == _Backend.TORCH_SDPA
|
||||||
else:
|
else:
|
||||||
with patch("vllm.attention.selector.current_platform",
|
# Test CUDA with head_size=64 (divisible by 32)
|
||||||
CudaPlatform()), \
|
# - should use vLLM's FlashAttention
|
||||||
patch("vllm.platforms.current_platform", CudaPlatform()):
|
with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
|
||||||
|
patch("vllm.model_executor.models.vision.current_platform",
|
||||||
|
CudaPlatform()):
|
||||||
attn = MultiHeadAttention(16, 64, scale=1)
|
attn = MultiHeadAttention(16, 64, scale=1)
|
||||||
assert attn.attn_backend == _Backend.XFORMERS
|
assert attn.attn_backend == _Backend.FLASH_ATTN
|
||||||
|
|
||||||
with patch("vllm.attention.selector.current_platform",
|
# Test CUDA with head_size=72 (not divisible by 32)
|
||||||
|
# - with upstream FA not available
|
||||||
|
# - should use xformers
|
||||||
|
with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
|
||||||
|
patch("vllm.model_executor.models.vision.current_platform",
|
||||||
CudaPlatform()), \
|
CudaPlatform()), \
|
||||||
patch("vllm.platforms.current_platform", CudaPlatform()):
|
patch("vllm.attention.layer.check_upstream_fa_availability",
|
||||||
|
return_value=False):
|
||||||
attn = MultiHeadAttention(16, 72, scale=1)
|
attn = MultiHeadAttention(16, 72, scale=1)
|
||||||
assert attn.attn_backend == _Backend.XFORMERS
|
assert attn.attn_backend == _Backend.XFORMERS
|
||||||
|
|
||||||
|
# Test CUDA with head_size=72 (not divisible by 32)
|
||||||
|
# - with upstream FA available
|
||||||
|
# - should use upstream FA
|
||||||
|
with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
|
||||||
|
patch("vllm.model_executor.models.vision.current_platform",
|
||||||
|
CudaPlatform()), \
|
||||||
|
patch("vllm.attention.layer.check_upstream_fa_availability",
|
||||||
|
return_value=True), \
|
||||||
|
patch.dict('sys.modules', {'flash_attn': type('MockFlashAttn', (),
|
||||||
|
{
|
||||||
|
'flash_attn_varlen_func': lambda *args, **kwargs: None
|
||||||
|
})()}):
|
||||||
|
attn = MultiHeadAttention(16, 72, scale=1)
|
||||||
|
assert attn.attn_backend == _Backend.FLASH_ATTN
|
||||||
|
|
||||||
|
|
||||||
def ref_attention(
|
def ref_attention(
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
|
|||||||
@ -23,6 +23,7 @@ from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
|||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig)
|
QuantizationConfig)
|
||||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||||
|
from vllm.model_executor.models.vision import get_vit_attn_backend
|
||||||
from vllm.platforms import _Backend, current_platform
|
from vllm.platforms import _Backend, current_platform
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
@ -55,6 +56,14 @@ def check_xformers_availability():
|
|||||||
return USE_XFORMERS_OPS
|
return USE_XFORMERS_OPS
|
||||||
|
|
||||||
|
|
||||||
|
def check_upstream_fa_availability(dtype: torch.dtype):
|
||||||
|
if dtype in (torch.float16, torch.bfloat16) and current_platform.is_cuda(
|
||||||
|
) and current_platform.has_device_capability(80):
|
||||||
|
from transformers.utils import is_flash_attn_2_available
|
||||||
|
return is_flash_attn_2_available()
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module, AttentionLayerBase):
|
class Attention(nn.Module, AttentionLayerBase):
|
||||||
"""Attention layer.
|
"""Attention layer.
|
||||||
|
|
||||||
@ -349,29 +358,55 @@ class MultiHeadAttention(nn.Module):
|
|||||||
f"divisible by num_kv_heads ({self.num_kv_heads})"
|
f"divisible by num_kv_heads ({self.num_kv_heads})"
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
|
|
||||||
|
# During model initialization, the default dtype is set as the model
|
||||||
|
# weight and activation dtype.
|
||||||
dtype = torch.get_default_dtype()
|
dtype = torch.get_default_dtype()
|
||||||
attn_backend = get_attn_backend(head_size,
|
|
||||||
dtype,
|
# Determine the attention backend
|
||||||
kv_cache_dtype=None,
|
backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)
|
||||||
block_size=16,
|
|
||||||
is_attention_free=False)
|
# Some auto-selected backends can be upgraded
|
||||||
backend = backend_name_to_enum(attn_backend.get_name())
|
# to upstream flash attention if available.
|
||||||
|
# If vllm native fa is selected, we use it directly.
|
||||||
|
use_upstream_fa = False
|
||||||
|
if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
|
||||||
|
dtype):
|
||||||
|
backend = _Backend.FLASH_ATTN
|
||||||
|
use_upstream_fa = True
|
||||||
|
|
||||||
if current_platform.is_rocm():
|
if current_platform.is_rocm():
|
||||||
# currently, only torch_sdpa is supported on rocm
|
# currently, only torch_sdpa is supported on rocm
|
||||||
self.attn_backend = _Backend.TORCH_SDPA
|
self.attn_backend = _Backend.TORCH_SDPA
|
||||||
else:
|
else:
|
||||||
|
|
||||||
self.attn_backend = backend if backend in {
|
self.attn_backend = backend if backend in {
|
||||||
_Backend.TORCH_SDPA,
|
_Backend.TORCH_SDPA,
|
||||||
_Backend.TORCH_SDPA_VLLM_V1,
|
_Backend.TORCH_SDPA_VLLM_V1,
|
||||||
_Backend.XFORMERS,
|
_Backend.XFORMERS,
|
||||||
_Backend.PALLAS_VLLM_V1,
|
_Backend.PALLAS_VLLM_V1,
|
||||||
_Backend.ROCM_AITER_FA,
|
_Backend.ROCM_AITER_FA,
|
||||||
} else current_platform.get_vit_attn_backend()
|
_Backend.FLASH_ATTN,
|
||||||
|
_Backend.FLASH_ATTN_VLLM_V1,
|
||||||
|
} else _Backend.TORCH_SDPA
|
||||||
|
|
||||||
if (self.attn_backend == _Backend.XFORMERS
|
if (self.attn_backend == _Backend.XFORMERS
|
||||||
and not check_xformers_availability()):
|
and not check_xformers_availability()):
|
||||||
self.attn_backend = _Backend.TORCH_SDPA
|
self.attn_backend = _Backend.TORCH_SDPA
|
||||||
|
|
||||||
|
if self.attn_backend in {
|
||||||
|
_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1
|
||||||
|
}:
|
||||||
|
if use_upstream_fa:
|
||||||
|
from flash_attn import flash_attn_varlen_func
|
||||||
|
self._flash_attn_varlen_func = flash_attn_varlen_func
|
||||||
|
else:
|
||||||
|
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||||
|
self._flash_attn_varlen_func = flash_attn_varlen_func
|
||||||
|
|
||||||
|
logger.info_once(
|
||||||
|
f"MultiHeadAttention attn_backend: {self.attn_backend}, "
|
||||||
|
f"use_upstream_fa: {use_upstream_fa}")
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
@ -392,7 +427,31 @@ class MultiHeadAttention(nn.Module):
|
|||||||
key = torch.repeat_interleave(key, num_repeat, dim=2)
|
key = torch.repeat_interleave(key, num_repeat, dim=2)
|
||||||
value = torch.repeat_interleave(value, num_repeat, dim=2)
|
value = torch.repeat_interleave(value, num_repeat, dim=2)
|
||||||
|
|
||||||
if self.attn_backend == _Backend.XFORMERS:
|
if self.attn_backend in {
|
||||||
|
_Backend.FLASH_ATTN,
|
||||||
|
_Backend.FLASH_ATTN_VLLM_V1,
|
||||||
|
}:
|
||||||
|
|
||||||
|
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
|
||||||
|
step=q_len,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=query.device)
|
||||||
|
cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
|
||||||
|
step=kv_len,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=key.device)
|
||||||
|
|
||||||
|
out = self._flash_attn_varlen_func(
|
||||||
|
query.flatten(0, 1),
|
||||||
|
key.flatten(0, 1),
|
||||||
|
value.flatten(0, 1),
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
max_seqlen_q=q_len,
|
||||||
|
max_seqlen_k=kv_len,
|
||||||
|
softmax_scale=self.scale,
|
||||||
|
)
|
||||||
|
elif self.attn_backend == _Backend.XFORMERS:
|
||||||
from xformers import ops as xops
|
from xformers import ops as xops
|
||||||
|
|
||||||
out = xops.memory_efficient_attention_forward(query,
|
out = xops.memory_efficient_attention_forward(query,
|
||||||
|
|||||||
@ -34,6 +34,7 @@ import torch.nn.functional as F
|
|||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from transformers import BatchFeature
|
from transformers import BatchFeature
|
||||||
|
|
||||||
|
from vllm.attention.layer import check_upstream_fa_availability
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import parallel_state
|
from vllm.distributed import parallel_state
|
||||||
from vllm.distributed import utils as dist_utils
|
from vllm.distributed import utils as dist_utils
|
||||||
@ -170,7 +171,16 @@ class Ernie4_5_VisionAttention(nn.Module):
|
|||||||
prefix=f"{prefix}.proj")
|
prefix=f"{prefix}.proj")
|
||||||
|
|
||||||
# Detect attention implementation.
|
# Detect attention implementation.
|
||||||
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
|
self.attn_backend = get_vit_attn_backend(
|
||||||
|
head_size=self.hidden_size_per_attention_head,
|
||||||
|
dtype=torch.get_default_dtype())
|
||||||
|
|
||||||
|
self.use_upstream_fa = False
|
||||||
|
if self.attn_backend != _Backend.FLASH_ATTN and \
|
||||||
|
check_upstream_fa_availability(torch.get_default_dtype()):
|
||||||
|
self.attn_backend = _Backend.FLASH_ATTN
|
||||||
|
self.use_upstream_fa = True
|
||||||
|
|
||||||
if self.attn_backend not in {
|
if self.attn_backend not in {
|
||||||
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
|
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
|
||||||
_Backend.ROCM_AITER_FA
|
_Backend.ROCM_AITER_FA
|
||||||
@ -233,7 +243,10 @@ class Ernie4_5_VisionAttention(nn.Module):
|
|||||||
if self.attn_backend == _Backend.ROCM_AITER_FA:
|
if self.attn_backend == _Backend.ROCM_AITER_FA:
|
||||||
from aiter import flash_attn_varlen_func
|
from aiter import flash_attn_varlen_func
|
||||||
else:
|
else:
|
||||||
from flash_attn import flash_attn_varlen_func
|
if self.use_upstream_fa:
|
||||||
|
from flash_attn import flash_attn_varlen_func
|
||||||
|
else:
|
||||||
|
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||||
|
|
||||||
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
||||||
|
|
||||||
@ -457,7 +470,11 @@ class Ernie4_5_VisionTransformer(nn.Module):
|
|||||||
), "vit's config.hidden must be equal to config.embed_dim"
|
), "vit's config.hidden must be equal to config.embed_dim"
|
||||||
self.ln = nn.LayerNorm(hidden_size, eps=1e-6)
|
self.ln = nn.LayerNorm(hidden_size, eps=1e-6)
|
||||||
|
|
||||||
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
|
self.attn_backend = get_vit_attn_backend(
|
||||||
|
head_size=head_dim, dtype=torch.get_default_dtype())
|
||||||
|
if self.attn_backend != _Backend.FLASH_ATTN and \
|
||||||
|
check_upstream_fa_availability(torch.get_default_dtype()):
|
||||||
|
self.attn_backend = _Backend.FLASH_ATTN
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self) -> torch.dtype:
|
def dtype(self) -> torch.dtype:
|
||||||
|
|||||||
@ -44,6 +44,7 @@ from transformers.models.glm4v.video_processing_glm4v import (
|
|||||||
Glm4vVideoProcessor)
|
Glm4vVideoProcessor)
|
||||||
from transformers.video_utils import VideoMetadata
|
from transformers.video_utils import VideoMetadata
|
||||||
|
|
||||||
|
from vllm.attention.layer import check_upstream_fa_availability
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||||
parallel_state)
|
parallel_state)
|
||||||
@ -260,7 +261,15 @@ class Glm4vVisionAttention(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Detect attention implementation.
|
# Detect attention implementation.
|
||||||
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
|
self.attn_backend = get_vit_attn_backend(
|
||||||
|
head_size=self.hidden_size_per_attention_head,
|
||||||
|
dtype=torch.get_default_dtype())
|
||||||
|
self.use_upstream_fa = False
|
||||||
|
if self.attn_backend != _Backend.FLASH_ATTN and \
|
||||||
|
check_upstream_fa_availability(torch.get_default_dtype()):
|
||||||
|
self.attn_backend = _Backend.FLASH_ATTN
|
||||||
|
self.use_upstream_fa = True
|
||||||
|
|
||||||
if self.attn_backend not in {
|
if self.attn_backend not in {
|
||||||
_Backend.FLASH_ATTN,
|
_Backend.FLASH_ATTN,
|
||||||
_Backend.TORCH_SDPA,
|
_Backend.TORCH_SDPA,
|
||||||
@ -310,7 +319,10 @@ class Glm4vVisionAttention(nn.Module):
|
|||||||
if self.attn_backend == _Backend.FLASH_ATTN:
|
if self.attn_backend == _Backend.FLASH_ATTN:
|
||||||
# from vllm_flash_attn.flash_attn_interface import (
|
# from vllm_flash_attn.flash_attn_interface import (
|
||||||
# flash_attn_varlen_func)
|
# flash_attn_varlen_func)
|
||||||
from flash_attn import flash_attn_varlen_func
|
if self.use_upstream_fa:
|
||||||
|
from flash_attn import flash_attn_varlen_func
|
||||||
|
else:
|
||||||
|
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||||
|
|
||||||
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
||||||
|
|
||||||
@ -715,7 +727,11 @@ class Glm4vVisionTransformer(nn.Module):
|
|||||||
self.post_layernorm = RMSNorm(vision_config.hidden_size,
|
self.post_layernorm = RMSNorm(vision_config.hidden_size,
|
||||||
eps=vision_config.rms_norm_eps)
|
eps=vision_config.rms_norm_eps)
|
||||||
|
|
||||||
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
|
self.attn_backend = get_vit_attn_backend(
|
||||||
|
head_size=head_dim, dtype=torch.get_default_dtype())
|
||||||
|
if self.attn_backend != _Backend.FLASH_ATTN and \
|
||||||
|
check_upstream_fa_availability(torch.get_default_dtype()):
|
||||||
|
self.attn_backend = _Backend.FLASH_ATTN
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self) -> torch.dtype:
|
def dtype(self) -> torch.dtype:
|
||||||
|
|||||||
@ -17,6 +17,7 @@ from transformers.modeling_outputs import (BaseModelOutput,
|
|||||||
BaseModelOutputWithPooling)
|
BaseModelOutputWithPooling)
|
||||||
from transformers.utils import torch_int
|
from transformers.utils import torch_int
|
||||||
|
|
||||||
|
from vllm.attention.layer import check_upstream_fa_availability
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -374,7 +375,16 @@ class KeyeSiglipAttention(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Detect attention implementation.
|
# Detect attention implementation.
|
||||||
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
|
self.attn_backend = get_vit_attn_backend(
|
||||||
|
head_size=self.head_dim, dtype=torch.get_default_dtype())
|
||||||
|
|
||||||
|
self.use_upstream_fa = False
|
||||||
|
if self.attn_backend != _Backend.FLASH_ATTN and \
|
||||||
|
check_upstream_fa_availability(
|
||||||
|
torch.get_default_dtype()):
|
||||||
|
self.attn_backend = _Backend.FLASH_ATTN
|
||||||
|
self.use_upstream_fa = True
|
||||||
|
|
||||||
if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}:
|
if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Keye-VL does not support {self.attn_backend} backend now.")
|
f"Keye-VL does not support {self.attn_backend} backend now.")
|
||||||
@ -428,7 +438,10 @@ class KeyeSiglipAttention(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.attn_backend == _Backend.FLASH_ATTN:
|
if self.attn_backend == _Backend.FLASH_ATTN:
|
||||||
from flash_attn import flash_attn_varlen_func
|
if self.use_upstream_fa:
|
||||||
|
from flash_attn import flash_attn_varlen_func
|
||||||
|
else:
|
||||||
|
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||||
|
|
||||||
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
||||||
|
|
||||||
|
|||||||
@ -38,6 +38,7 @@ from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
|
|||||||
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
||||||
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
|
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
|
||||||
|
|
||||||
|
from vllm.attention.layer import check_upstream_fa_availability
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import parallel_state
|
from vllm.distributed import parallel_state
|
||||||
from vllm.distributed import utils as dist_utils
|
from vllm.distributed import utils as dist_utils
|
||||||
@ -298,7 +299,16 @@ class Qwen2_5_VisionAttention(nn.Module):
|
|||||||
disable_tp=use_data_parallel)
|
disable_tp=use_data_parallel)
|
||||||
|
|
||||||
# Detect attention implementation.
|
# Detect attention implementation.
|
||||||
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
|
self.attn_backend = get_vit_attn_backend(
|
||||||
|
head_size=self.hidden_size_per_attention_head,
|
||||||
|
dtype=torch.get_default_dtype())
|
||||||
|
self.use_upstream_fa = False
|
||||||
|
if self.attn_backend != _Backend.FLASH_ATTN and \
|
||||||
|
check_upstream_fa_availability(
|
||||||
|
torch.get_default_dtype()):
|
||||||
|
self.attn_backend = _Backend.FLASH_ATTN
|
||||||
|
self.use_upstream_fa = True
|
||||||
|
|
||||||
if self.attn_backend not in {
|
if self.attn_backend not in {
|
||||||
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
|
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
|
||||||
_Backend.ROCM_AITER_FA
|
_Backend.ROCM_AITER_FA
|
||||||
@ -359,7 +369,10 @@ class Qwen2_5_VisionAttention(nn.Module):
|
|||||||
if self.attn_backend == _Backend.ROCM_AITER_FA:
|
if self.attn_backend == _Backend.ROCM_AITER_FA:
|
||||||
from aiter import flash_attn_varlen_func
|
from aiter import flash_attn_varlen_func
|
||||||
else:
|
else:
|
||||||
from flash_attn import flash_attn_varlen_func
|
if self.use_upstream_fa:
|
||||||
|
from flash_attn import flash_attn_varlen_func
|
||||||
|
else:
|
||||||
|
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||||
|
|
||||||
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
||||||
|
|
||||||
@ -628,7 +641,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
|||||||
prefix=f"{prefix}.merger",
|
prefix=f"{prefix}.merger",
|
||||||
use_data_parallel=use_data_parallel,
|
use_data_parallel=use_data_parallel,
|
||||||
)
|
)
|
||||||
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
|
self.attn_backend = get_vit_attn_backend(
|
||||||
|
head_size=head_dim, dtype=torch.get_default_dtype())
|
||||||
|
if self.attn_backend != _Backend.FLASH_ATTN and \
|
||||||
|
check_upstream_fa_availability(
|
||||||
|
torch.get_default_dtype()):
|
||||||
|
self.attn_backend = _Backend.FLASH_ATTN
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self) -> torch.dtype:
|
def dtype(self) -> torch.dtype:
|
||||||
|
|||||||
@ -41,6 +41,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
|
|||||||
from transformers.models.qwen2_vl.video_processing_qwen2_vl import (
|
from transformers.models.qwen2_vl.video_processing_qwen2_vl import (
|
||||||
Qwen2VLVideoProcessor)
|
Qwen2VLVideoProcessor)
|
||||||
|
|
||||||
|
from vllm.attention.layer import check_upstream_fa_availability
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
|
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
|
||||||
from vllm.distributed import utils as dist_utils
|
from vllm.distributed import utils as dist_utils
|
||||||
@ -314,7 +315,16 @@ class Qwen2VisionAttention(nn.Module):
|
|||||||
prefix=f"{prefix}.proj")
|
prefix=f"{prefix}.proj")
|
||||||
|
|
||||||
# Detect attention implementation.
|
# Detect attention implementation.
|
||||||
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
|
self.attn_backend = get_vit_attn_backend(
|
||||||
|
head_size=self.hidden_size_per_attention_head,
|
||||||
|
dtype=torch.get_default_dtype())
|
||||||
|
self.use_upstream_fa = False
|
||||||
|
if self.attn_backend != _Backend.FLASH_ATTN and \
|
||||||
|
check_upstream_fa_availability(
|
||||||
|
torch.get_default_dtype()):
|
||||||
|
self.attn_backend = _Backend.FLASH_ATTN
|
||||||
|
self.use_upstream_fa = True
|
||||||
|
|
||||||
if self.attn_backend not in {
|
if self.attn_backend not in {
|
||||||
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
|
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
|
||||||
_Backend.ROCM_AITER_FA
|
_Backend.ROCM_AITER_FA
|
||||||
@ -374,7 +384,10 @@ class Qwen2VisionAttention(nn.Module):
|
|||||||
if self.attn_backend == _Backend.ROCM_AITER_FA:
|
if self.attn_backend == _Backend.ROCM_AITER_FA:
|
||||||
from aiter import flash_attn_varlen_func
|
from aiter import flash_attn_varlen_func
|
||||||
else:
|
else:
|
||||||
from flash_attn import flash_attn_varlen_func
|
if self.use_upstream_fa:
|
||||||
|
from flash_attn import flash_attn_varlen_func
|
||||||
|
else:
|
||||||
|
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||||
|
|
||||||
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
||||||
|
|
||||||
@ -628,7 +641,12 @@ class Qwen2VisionTransformer(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.merger",
|
prefix=f"{prefix}.merger",
|
||||||
)
|
)
|
||||||
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
|
self.attn_backend = get_vit_attn_backend(
|
||||||
|
head_size=head_dim, dtype=torch.get_default_dtype())
|
||||||
|
if self.attn_backend != _Backend.FLASH_ATTN and \
|
||||||
|
check_upstream_fa_availability(
|
||||||
|
torch.get_default_dtype()):
|
||||||
|
self.attn_backend = _Backend.FLASH_ATTN
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self) -> torch.dtype:
|
def dtype(self) -> torch.dtype:
|
||||||
|
|||||||
@ -13,6 +13,7 @@ from torch.nn import functional as F
|
|||||||
from transformers import Siglip2VisionConfig
|
from transformers import Siglip2VisionConfig
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
from vllm.attention.layer import check_upstream_fa_availability
|
||||||
from vllm.config import QuantizationConfig
|
from vllm.config import QuantizationConfig
|
||||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
@ -236,7 +237,15 @@ class Siglip2Attention(nn.Module):
|
|||||||
self.use_rope = config.use_rope
|
self.use_rope = config.use_rope
|
||||||
|
|
||||||
# Detect attention implementation.
|
# Detect attention implementation.
|
||||||
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
|
self.attn_backend = get_vit_attn_backend(
|
||||||
|
head_size=self.head_dim, dtype=torch.get_default_dtype())
|
||||||
|
self.use_upstream_fa = False
|
||||||
|
if self.attn_backend != _Backend.FLASH_ATTN and \
|
||||||
|
check_upstream_fa_availability(
|
||||||
|
torch.get_default_dtype()):
|
||||||
|
self.attn_backend = _Backend.FLASH_ATTN
|
||||||
|
self.use_upstream_fa = True
|
||||||
|
|
||||||
if self.attn_backend not in {
|
if self.attn_backend not in {
|
||||||
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA,
|
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA,
|
||||||
_Backend.ROCM_AITER_FA
|
_Backend.ROCM_AITER_FA
|
||||||
@ -280,7 +289,10 @@ class Siglip2Attention(nn.Module):
|
|||||||
if self.attn_backend == _Backend.ROCM_AITER_FA:
|
if self.attn_backend == _Backend.ROCM_AITER_FA:
|
||||||
from aiter import flash_attn_varlen_func
|
from aiter import flash_attn_varlen_func
|
||||||
else:
|
else:
|
||||||
from flash_attn import flash_attn_varlen_func
|
if self.use_upstream_fa:
|
||||||
|
from flash_attn import flash_attn_varlen_func
|
||||||
|
else:
|
||||||
|
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||||
attn_output = flash_attn_varlen_func(
|
attn_output = flash_attn_varlen_func(
|
||||||
queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen,
|
queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen,
|
||||||
max_seqlen).reshape(seq_length, -1)
|
max_seqlen).reshape(seq_length, -1)
|
||||||
|
|||||||
@ -7,7 +7,6 @@ from typing import Final, Generic, Optional, Protocol, TypeVar, Union
|
|||||||
import torch
|
import torch
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.attention.selector import get_env_variable_attn_backend
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import _Backend, current_platform
|
from vllm.platforms import _Backend, current_platform
|
||||||
|
|
||||||
@ -68,17 +67,18 @@ def get_vision_encoder_info(
|
|||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
|
|
||||||
def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
|
def get_vit_attn_backend(head_size: int, dtype: torch.dtype) -> _Backend:
|
||||||
"""
|
"""
|
||||||
Get the available attention backend for Vision Transformer.
|
Get the available attention backend for Vision Transformer.
|
||||||
"""
|
"""
|
||||||
# TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
|
# Lazy import to avoid circular dependency
|
||||||
|
from vllm.attention.selector import get_env_variable_attn_backend
|
||||||
|
|
||||||
selected_backend: Optional[_Backend] = get_env_variable_attn_backend()
|
selected_backend: Optional[_Backend] = get_env_variable_attn_backend()
|
||||||
if selected_backend is not None:
|
if selected_backend is not None:
|
||||||
return selected_backend
|
return selected_backend
|
||||||
|
|
||||||
return current_platform.get_vit_attn_backend(support_fa)
|
return current_platform.get_vit_attn_backend(head_size, dtype)
|
||||||
|
|
||||||
|
|
||||||
def resolve_visual_encoder_outputs(
|
def resolve_visual_encoder_outputs(
|
||||||
|
|||||||
@ -209,18 +209,24 @@ class CudaPlatformBase(Platform):
|
|||||||
return torch.cuda.max_memory_allocated(device)
|
return torch.cuda.max_memory_allocated(device)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend:
|
def get_vit_attn_backend(cls, head_size: int,
|
||||||
if cls.has_device_capability(80) and support_fa:
|
dtype: torch.dtype) -> _Backend:
|
||||||
from transformers.utils import is_flash_attn_2_available
|
if dtype not in (torch.float16, torch.bfloat16):
|
||||||
if is_flash_attn_2_available():
|
return _Backend.XFORMERS
|
||||||
|
|
||||||
|
if cls.has_device_capability(80):
|
||||||
|
FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
|
||||||
|
from vllm.attention.selector import is_attn_backend_supported
|
||||||
|
is_default_fa_supported = is_attn_backend_supported(
|
||||||
|
FLASH_ATTN_V1, head_size, dtype, allow_import_error=False)
|
||||||
|
if is_default_fa_supported:
|
||||||
return _Backend.FLASH_ATTN
|
return _Backend.FLASH_ATTN
|
||||||
logger.warning_once(
|
else:
|
||||||
"Current `vllm-flash-attn` has a bug inside vision "
|
# Fallback to XFORMERS
|
||||||
"module, so we use xformers backend instead. You can "
|
return _Backend.XFORMERS
|
||||||
"run `pip install flash-attn` to use flash-attention "
|
else:
|
||||||
"backend.")
|
# Fallback for Volta/Turing GPUs or FA not supported
|
||||||
# Fallback for Volta/Turing GPUs or FA not supported
|
return _Backend.XFORMERS
|
||||||
return _Backend.XFORMERS
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
|
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
|
||||||
|
|||||||
@ -192,7 +192,8 @@ class Platform:
|
|||||||
return device_id
|
return device_id
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend:
|
def get_vit_attn_backend(cls, head_size: int,
|
||||||
|
dtype: torch.dtype) -> _Backend:
|
||||||
return _Backend.TORCH_SDPA
|
return _Backend.TORCH_SDPA
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -175,15 +175,15 @@ class RocmPlatform(Platform):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend:
|
def get_vit_attn_backend(cls, head_size: int,
|
||||||
if support_fa:
|
dtype: torch.dtype) -> _Backend:
|
||||||
if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA
|
if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA
|
||||||
and on_gfx9()):
|
and on_gfx9()):
|
||||||
# Note: AITER FA is only supported for Qwen-VL models.
|
# Note: AITER FA is only supported for Qwen-VL models.
|
||||||
# TODO: Add support for other VL models in their model class.
|
# TODO: Add support for other VL models in their model class.
|
||||||
return _Backend.ROCM_AITER_FA
|
return _Backend.ROCM_AITER_FA
|
||||||
if on_gfx9():
|
if on_gfx9():
|
||||||
return _Backend.FLASH_ATTN
|
return _Backend.FLASH_ATTN
|
||||||
return _Backend.TORCH_SDPA
|
return _Backend.TORCH_SDPA
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user