mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-07 06:09:09 +08:00
[Bugfix] Fix ViT with FlashAttention on ROCm (#30703)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
parent
ec154c36ee
commit
51e5b3e3c4
@ -464,7 +464,10 @@ class MultiHeadAttention(nn.Module):
|
|||||||
}
|
}
|
||||||
|
|
||||||
self.fa_version = None
|
self.fa_version = None
|
||||||
if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
|
if (
|
||||||
|
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||||
|
and current_platform.is_cuda()
|
||||||
|
):
|
||||||
self.fa_version = get_flash_attn_version()
|
self.fa_version = get_flash_attn_version()
|
||||||
assert self._flash_attn_varlen_func is not None
|
assert self._flash_attn_varlen_func is not None
|
||||||
self._flash_attn_varlen_func = functools.partial(
|
self._flash_attn_varlen_func = functools.partial(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user