[Bugfix][V1][ROCm] Fix AITER Flash Attention Backend (Fix API Break and Local Attention Logic: affecting Llama4) (#19904)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
TJian 2025-06-26 05:42:31 -07:00 committed by GitHub
parent 84c260caeb
commit 27c065df50
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 46 additions and 23 deletions

View File

@ -306,6 +306,10 @@ class MultiHeadAttention(nn.Module):
block_size=16,
is_attention_free=False)
backend = backend_name_to_enum(attn_backend.get_name())
if current_platform.is_rocm():
# currently, only torch_sdpa is supported on rocm
self.attn_backend = _Backend.TORCH_SDPA
else:
if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
backend = _Backend.XFORMERS

View File

@ -243,8 +243,8 @@ class AiterFlashAttentionMetadataBuilder:
self.runner.device, non_blocking=True)
local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
self.runner.device, non_blocking=True)
local_max_query_len = seqlens_q_local_np.max()
local_max_seq_len = virt_k_seqlens_np.max()
local_max_query_len = int(seqlens_q_local_np.max())
local_max_seq_len = int(virt_k_seqlens_np.max())
local_scheduler_metadata = schedule(
batch_size=local_query_start_loc.shape[0] - 1,
cu_query_lens=local_query_start_loc,
@ -253,6 +253,17 @@ class AiterFlashAttentionMetadataBuilder:
max_seq_len=local_max_seq_len,
causal=True)
local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
dtype=torch.int32,
device=self.runner.device)
local_cu_seq_lens[1:] = torch.cumsum(
torch.from_numpy(virt_k_seqlens_np).to(
device=self.runner.device,
dtype=torch.int32,
non_blocking=True),
dim=0)
local_attn_metadata = \
AiterFlashAttentionMetadata.LocalAttentionMetadata(
local_query_start_loc=local_query_start_loc,
@ -260,6 +271,7 @@ class AiterFlashAttentionMetadataBuilder:
local_block_table=virt_block_table_tensor,
local_max_query_len=local_max_query_len,
local_max_seq_len=local_max_seq_len,
local_cu_seq_lens=local_cu_seq_lens,
local_scheduler_metadata=local_scheduler_metadata,
)
@ -368,6 +380,7 @@ class AiterFlashAttentionMetadata:
local_block_table: torch.Tensor
local_max_query_len: int
local_max_seq_len: int
local_cu_seq_lens: torch.Tensor
local_scheduler_metadata: Optional[torch.Tensor]
local_attn_metadata: Optional[LocalAttentionMetadata] = None
@ -387,6 +400,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[int] = None,
use_irope: bool = False,
) -> None:
if blocksparse_params is not None:
@ -408,6 +422,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0.
self.logits_soft_cap = logits_soft_cap
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@ -478,12 +493,15 @@ class AiterFlashAttentionImpl(AttentionImpl):
# performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens
# Reshape the input keys and values and store them in the cache.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens] and
# value[:num_actual_tokens] because the reshape_and_cache_flash op uses
# the slot_mapping's shape to determine the number of actual tokens.
key_cache, value_cache = kv_cache.unbind(0)
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
@ -541,7 +559,8 @@ class AiterFlashAttentionImpl(AttentionImpl):
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
cu_seqlens_k=cu_seq_lens,
cu_seqlens_k=(cu_seq_lens if not use_local_attn else
local_metadata.local_cu_seq_lens),
)
_, num_heads, head_size = query.shape