[attn][tiny fix] fix attn backend in MultiHeadAttention (#11463)

Signed-off-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao 2024-12-24 20:39:36 +08:00 committed by GitHub
parent 461cde2080
commit 5c7963249d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -191,6 +191,7 @@ class MultiHeadAttention(nn.Module):
kv_cache_dtype=None,
block_size=16,
is_attention_free=False)
attn_backend = backend_name_to_enum(attn_backend.get_name())
if attn_backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
attn_backend = _Backend.XFORMERS