[ROCm][Bugfix] Fix fa_version argument error in flash_attn_maxseqlen_wrapper for ROCm without aiter (#30909)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
commit 8da6ae49c3
parent 30bb19a760
@@ -28,7 +28,7 @@ def flash_attn_maxseqlen_wrapper(
     max_seqlen: torch.Tensor,
     batch_size: int,
     is_rocm_aiter: bool,
-    fa_version: int,
+    fa_version: int | None,
 ) -> torch.Tensor:
     kwargs = {}
     if is_rocm_aiter:
@@ -36,7 +36,8 @@ def flash_attn_maxseqlen_wrapper(
     else:
         from vllm.attention.utils.fa_utils import flash_attn_varlen_func

-        kwargs["fa_version"] = fa_version
+        if not current_platform.is_rocm() and fa_version is not None:
+            kwargs["fa_version"] = fa_version
     q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
     output = flash_attn_varlen_func(
         q,
@@ -62,7 +63,7 @@ def flash_attn_maxseqlen_wrapper_fake(
     max_seqlen: torch.Tensor,
     batch_size: int,
     is_rocm_aiter: bool,
-    fa_version: int,
+    fa_version: int | None,
 ) -> torch.Tensor:
     return torch.empty_like(q)

@@ -82,7 +83,7 @@ def vit_flash_attn_wrapper(
     max_seqlen: torch.Tensor,
     batch_size: int,
     is_rocm_aiter: bool,
-    fa_version: int,
+    fa_version: int | None,
 ) -> torch.Tensor:
     return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
         q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, fa_version
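
Below is a minimal sketch of the gating this commit introduces, assuming a
hypothetical helper build_fa_kwargs that mirrors the diff above; per the
commit message, the ROCm flash_attn_varlen_func (without aiter) does not
accept an fa_version keyword, so unconditionally forwarding it raised an
argument error.

def build_fa_kwargs(is_rocm: bool, fa_version: int | None) -> dict:
    # Only forward fa_version on platforms whose flash_attn_varlen_func
    # accepts it; on ROCm without aiter the kwarg is simply dropped.
    kwargs: dict = {}
    if not is_rocm and fa_version is not None:
        kwargs["fa_version"] = fa_version
    return kwargs

# On CUDA the requested FA version is forwarded unchanged...
assert build_fa_kwargs(is_rocm=False, fa_version=3) == {"fa_version": 3}
# ...while on ROCm it is dropped instead of triggering the argument error.
assert build_fa_kwargs(is_rocm=True, fa_version=3) == {}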