[Bugfix] Fix ViT with FlashAttention on ROCm (#30703)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Matthew Bonanni 2025-12-15 14:45:21 -05:00 committed by GitHub
parent ec154c36ee
commit 51e5b3e3c4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -464,7 +464,10 @@ class MultiHeadAttention(nn.Module):
         }
         self.fa_version = None
-        if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
+        if (
+            self.attn_backend == AttentionBackendEnum.FLASH_ATTN
+            and current_platform.is_cuda()
+        ):
             self.fa_version = get_flash_attn_version()
             assert self._flash_attn_varlen_func is not None
             self._flash_attn_varlen_func = functools.partial(