diff --git a/vllm/attention/ops/vit_attn_wrappers.py b/vllm/attention/ops/vit_attn_wrappers.py
index 5a74e1310133d..f555147bc055a 100644
--- a/vllm/attention/ops/vit_attn_wrappers.py
+++ b/vllm/attention/ops/vit_attn_wrappers.py
@@ -28,7 +28,7 @@ def flash_attn_maxseqlen_wrapper(
     max_seqlen: torch.Tensor,
     batch_size: int,
     is_rocm_aiter: bool,
-    fa_version: int,
+    fa_version: int | None,
 ) -> torch.Tensor:
     kwargs = {}
     if is_rocm_aiter:
@@ -36,7 +36,8 @@ def flash_attn_maxseqlen_wrapper(
     else:
         from vllm.attention.utils.fa_utils import flash_attn_varlen_func
 
-        kwargs["fa_version"] = fa_version
+        if not current_platform.is_rocm() and fa_version is not None:
+            kwargs["fa_version"] = fa_version
     q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
     output = flash_attn_varlen_func(
         q,
@@ -62,7 +63,7 @@ def flash_attn_maxseqlen_wrapper_fake(
     max_seqlen: torch.Tensor,
     batch_size: int,
     is_rocm_aiter: bool,
-    fa_version: int,
+    fa_version: int | None,
 ) -> torch.Tensor:
     return torch.empty_like(q)
 
@@ -82,7 +83,7 @@ def vit_flash_attn_wrapper(
     max_seqlen: torch.Tensor,
     batch_size: int,
     is_rocm_aiter: bool,
-    fa_version: int,
+    fa_version: int | None,
 ) -> torch.Tensor:
     return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
         q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, fa_version
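
A minimal, self-contained sketch (not part of the patch) of the guard this change introduces: an optional fa_version is only forwarded into the flash-attention kwargs when a version was explicitly provided and the platform is not ROCm. The helper name build_fa_kwargs and the plain is_rocm flag are illustrative stand-ins for the wrapper's internals, which use current_platform.is_rocm() as shown in the hunk above.

from typing import Any


def build_fa_kwargs(fa_version: int | None, is_rocm: bool) -> dict[str, Any]:
    """Forward fa_version only when it is set and the platform is not ROCm,
    mirroring the guard added in flash_attn_maxseqlen_wrapper (sketch only)."""
    kwargs: dict[str, Any] = {}
    if not is_rocm and fa_version is not None:
        kwargs["fa_version"] = fa_version
    return kwargs


# fa_version=None and ROCm platforms both skip the kwarg entirely.
assert build_fa_kwargs(None, is_rocm=False) == {}
assert build_fa_kwargs(2, is_rocm=False) == {"fa_version": 2}
assert build_fa_kwargs(3, is_rocm=True) == {}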