mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 14:07:13 +08:00
[Bugfix] Fix ViT with FlashAttention on ROCm (#30703)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
parent
ec154c36ee
commit
51e5b3e3c4
@@ -464,7 +464,10 @@ class MultiHeadAttention(nn.Module):
         }
 
         self.fa_version = None
-        if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
+        if (
+            self.attn_backend == AttentionBackendEnum.FLASH_ATTN
+            and current_platform.is_cuda()
+        ):
             self.fa_version = get_flash_attn_version()
             assert self._flash_attn_varlen_func is not None
             self._flash_attn_varlen_func = functools.partial(
Loading…
x
Reference in New Issue
Block a user