From e806178d2a9b65ebd536342d58097a825d066b9e Mon Sep 17 00:00:00 2001
From: Zhewen Li
Date: Thu, 30 Oct 2025 00:54:44 -0700
Subject: [PATCH] [BugFix][VL] Fix FA selection on Qwen2.5-VL (#27790)

Signed-off-by: zhewenli
Co-authored-by: Roger Wang
---
 .buildkite/test-amd.yaml                 |  2 +-
 vllm/model_executor/models/qwen2_5_vl.py | 30 +++++++++++++++---------
 2 files changed, 20 insertions(+), 12 deletions(-)

diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml
index 56e7b1083b17e..35bd4c99adb78 100644
--- a/.buildkite/test-amd.yaml
+++ b/.buildkite/test-amd.yaml
@@ -318,7 +318,7 @@ steps:
 - label: V1 Test entrypoints # 35min
   timeout_in_minutes: 50
-  mirror_hardwares: [amdexperimental]
+  mirror_hardwares: [amdexperimental, amdproduction]
   agent_pool: mi325_1
   # grade: Blocking
   source_file_dependencies:
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index dfaeb663bbe2f..3d67653726bd8 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -43,10 +43,7 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
 )
 
 from vllm.attention.backends.registry import _Backend
-from vllm.attention.layer import (
-    check_upstream_fa_availability,
-    maybe_get_vit_flash_attn_backend,
-)
+from vllm.attention.layer import maybe_get_vit_flash_attn_backend
 from vllm.attention.ops.vit_attn_wrappers import (
     vit_flash_attn_wrapper,
     vit_xformers_attn_wrapper,
@@ -318,6 +315,7 @@ class Qwen2_5_VisionAttention(nn.Module):
         use_data_parallel: bool = False,
         attn_backend: _Backend = _Backend.TORCH_SDPA,
         use_upstream_fa: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
@@ -358,8 +356,14 @@ class Qwen2_5_VisionAttention(nn.Module):
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
                 self.use_upstream_fa,
+                attn_backend_override=attn_backend_override,
             )
         )
+        # On ROCm with FLASH_ATTN backend, upstream flash_attn is used
+        from vllm.platforms import current_platform
+
+        if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN:
+            self.use_upstream_fa = True
         self.is_flash_attn_backend = self.attn_backend in {
             _Backend.FLASH_ATTN,
             _Backend.ROCM_AITER_FA,
@@ -484,6 +488,7 @@ class Qwen2_5_VisionBlock(nn.Module):
         use_data_parallel: bool = False,
         attn_backend: _Backend = _Backend.TORCH_SDPA,
         use_upstream_fa: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -499,6 +504,7 @@ class Qwen2_5_VisionBlock(nn.Module):
             use_data_parallel=use_data_parallel,
             attn_backend=attn_backend,
             use_upstream_fa=use_upstream_fa,
+            attn_backend_override=attn_backend_override,
         )
         self.mlp = Qwen2_5_VisionMLP(
             dim,
@@ -698,13 +704,14 @@ class Qwen2_5_VisionTransformer(nn.Module):
             dtype=torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        if (
-            self.attn_backend != _Backend.FLASH_ATTN
-            and self.attn_backend != _Backend.ROCM_AITER_FA
-            and check_upstream_fa_availability(torch.get_default_dtype())
-        ):
-            self.attn_backend = _Backend.FLASH_ATTN
-            use_upstream_fa = True
+
+        self.attn_backend, self.flash_attn_varlen_func = (
+            maybe_get_vit_flash_attn_backend(
+                self.attn_backend,
+                use_upstream_fa,
+                attn_backend_override=attn_backend_override,
+            )
+        )
 
         if self.attn_backend not in {
             _Backend.FLASH_ATTN,
@@ -730,6 +737,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
                     use_data_parallel=use_data_parallel,
                     attn_backend=self.attn_backend,
                     use_upstream_fa=use_upstream_fa,
+                    attn_backend_override=attn_backend_override,
                 )
                 for layer_idx in range(depth)
             ]
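
Note (illustrative sketch, not part of the patch): the snippet below shows in plain Python the pattern the diff applies to Qwen2.5-VL -- an optional attn_backend_override is threaded from the top-level vision transformer through each block down to the attention layer, and the final backend choice is made once by a helper (here a toy resolve_vit_attn_backend standing in for vLLM's maybe_get_vit_flash_attn_backend) instead of each module re-checking upstream flash-attn availability. The Backend enum, the class names, and the _is_rocm() check are simplified placeholders for this sketch, not vLLM's actual API.

# Minimal, self-contained sketch of the "thread the backend override down" pattern.
from enum import Enum, auto


class Backend(Enum):
    TORCH_SDPA = auto()
    FLASH_ATTN = auto()
    ROCM_AITER_FA = auto()
    XFORMERS = auto()


def _is_rocm() -> bool:
    # Placeholder platform check for the sketch.
    return False


def resolve_vit_attn_backend(
    preferred: Backend,
    use_upstream_fa: bool,
    attn_backend_override: Backend | None = None,
) -> tuple[Backend, bool]:
    """Toy stand-in for maybe_get_vit_flash_attn_backend: honor an explicit
    override first, otherwise keep the preferred backend."""
    backend = attn_backend_override if attn_backend_override is not None else preferred
    # Mirrors the rule added by the patch: FLASH_ATTN on ROCm means the
    # upstream flash_attn kernels are used.
    if backend is Backend.FLASH_ATTN and _is_rocm():
        use_upstream_fa = True
    return backend, use_upstream_fa


class VisionAttention:
    def __init__(
        self,
        attn_backend: Backend = Backend.TORCH_SDPA,
        use_upstream_fa: bool = False,
        attn_backend_override: Backend | None = None,
    ) -> None:
        # The backend is resolved in one place, using the forwarded override.
        self.attn_backend, self.use_upstream_fa = resolve_vit_attn_backend(
            attn_backend, use_upstream_fa, attn_backend_override
        )


class VisionBlock:
    def __init__(self, attn_backend_override: Backend | None = None) -> None:
        # The override is forwarded untouched to the attention layer.
        self.attn = VisionAttention(attn_backend_override=attn_backend_override)


class VisionTransformer:
    def __init__(self, depth: int, attn_backend_override: Backend | None = None) -> None:
        # Every block (and therefore every attention layer) sees the same override.
        self.blocks = [
            VisionBlock(attn_backend_override=attn_backend_override)
            for _ in range(depth)
        ]


if __name__ == "__main__":
    vit = VisionTransformer(depth=2, attn_backend_override=Backend.FLASH_ATTN)
    print(vit.blocks[0].attn.attn_backend)  # Backend.FLASH_ATTN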