mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-31 17:27:07 +08:00
[BugFix][VL] Fix FA selection on Qwen2.5-VL (#27790)
Signed-off-by: zhewenli <zhewenli@meta.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
parent
5be1bed790
commit
e806178d2a
@ -318,7 +318,7 @@ steps:
|
|||||||
|
|
||||||
- label: V1 Test entrypoints # 35min
|
- label: V1 Test entrypoints # 35min
|
||||||
timeout_in_minutes: 50
|
timeout_in_minutes: 50
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_1
|
agent_pool: mi325_1
|
||||||
# grade: Blocking
|
# grade: Blocking
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
|
|||||||
@ -43,10 +43,7 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from vllm.attention.backends.registry import _Backend
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.attention.layer import (
|
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
|
||||||
check_upstream_fa_availability,
|
|
||||||
maybe_get_vit_flash_attn_backend,
|
|
||||||
)
|
|
||||||
from vllm.attention.ops.vit_attn_wrappers import (
|
from vllm.attention.ops.vit_attn_wrappers import (
|
||||||
vit_flash_attn_wrapper,
|
vit_flash_attn_wrapper,
|
||||||
vit_xformers_attn_wrapper,
|
vit_xformers_attn_wrapper,
|
||||||
@ -318,6 +315,7 @@ class Qwen2_5_VisionAttention(nn.Module):
|
|||||||
use_data_parallel: bool = False,
|
use_data_parallel: bool = False,
|
||||||
attn_backend: _Backend = _Backend.TORCH_SDPA,
|
attn_backend: _Backend = _Backend.TORCH_SDPA,
|
||||||
use_upstream_fa: bool = False,
|
use_upstream_fa: bool = False,
|
||||||
|
attn_backend_override: _Backend | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# Per attention head and per partition values.
|
# Per attention head and per partition values.
|
||||||
@ -358,8 +356,14 @@ class Qwen2_5_VisionAttention(nn.Module):
|
|||||||
maybe_get_vit_flash_attn_backend(
|
maybe_get_vit_flash_attn_backend(
|
||||||
self.attn_backend,
|
self.attn_backend,
|
||||||
self.use_upstream_fa,
|
self.use_upstream_fa,
|
||||||
|
attn_backend_override=attn_backend_override,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
# On ROCm with FLASH_ATTN backend, upstream flash_attn is used
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN:
|
||||||
|
self.use_upstream_fa = True
|
||||||
self.is_flash_attn_backend = self.attn_backend in {
|
self.is_flash_attn_backend = self.attn_backend in {
|
||||||
_Backend.FLASH_ATTN,
|
_Backend.FLASH_ATTN,
|
||||||
_Backend.ROCM_AITER_FA,
|
_Backend.ROCM_AITER_FA,
|
||||||
@ -484,6 +488,7 @@ class Qwen2_5_VisionBlock(nn.Module):
|
|||||||
use_data_parallel: bool = False,
|
use_data_parallel: bool = False,
|
||||||
attn_backend: _Backend = _Backend.TORCH_SDPA,
|
attn_backend: _Backend = _Backend.TORCH_SDPA,
|
||||||
use_upstream_fa: bool = False,
|
use_upstream_fa: bool = False,
|
||||||
|
attn_backend_override: _Backend | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if norm_layer is None:
|
if norm_layer is None:
|
||||||
@ -499,6 +504,7 @@ class Qwen2_5_VisionBlock(nn.Module):
|
|||||||
use_data_parallel=use_data_parallel,
|
use_data_parallel=use_data_parallel,
|
||||||
attn_backend=attn_backend,
|
attn_backend=attn_backend,
|
||||||
use_upstream_fa=use_upstream_fa,
|
use_upstream_fa=use_upstream_fa,
|
||||||
|
attn_backend_override=attn_backend_override,
|
||||||
)
|
)
|
||||||
self.mlp = Qwen2_5_VisionMLP(
|
self.mlp = Qwen2_5_VisionMLP(
|
||||||
dim,
|
dim,
|
||||||
@ -698,13 +704,14 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
|||||||
dtype=torch.get_default_dtype(),
|
dtype=torch.get_default_dtype(),
|
||||||
attn_backend_override=attn_backend_override,
|
attn_backend_override=attn_backend_override,
|
||||||
)
|
)
|
||||||
if (
|
|
||||||
self.attn_backend != _Backend.FLASH_ATTN
|
self.attn_backend, self.flash_attn_varlen_func = (
|
||||||
and self.attn_backend != _Backend.ROCM_AITER_FA
|
maybe_get_vit_flash_attn_backend(
|
||||||
and check_upstream_fa_availability(torch.get_default_dtype())
|
self.attn_backend,
|
||||||
):
|
use_upstream_fa,
|
||||||
self.attn_backend = _Backend.FLASH_ATTN
|
attn_backend_override=attn_backend_override,
|
||||||
use_upstream_fa = True
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if self.attn_backend not in {
|
if self.attn_backend not in {
|
||||||
_Backend.FLASH_ATTN,
|
_Backend.FLASH_ATTN,
|
||||||
@ -730,6 +737,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
|||||||
use_data_parallel=use_data_parallel,
|
use_data_parallel=use_data_parallel,
|
||||||
attn_backend=self.attn_backend,
|
attn_backend=self.attn_backend,
|
||||||
use_upstream_fa=use_upstream_fa,
|
use_upstream_fa=use_upstream_fa,
|
||||||
|
attn_backend_override=attn_backend_override,
|
||||||
)
|
)
|
||||||
for layer_idx in range(depth)
|
for layer_idx in range(depth)
|
||||||
]
|
]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user