[Bugfix][Qwen][DCA] Fix a bug in the dual-chunk-flash-attn backend for Qwen 1M models (#21364)

Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
Tao He 2025-07-23 21:34:37 +08:00 committed by GitHub
parent f59ec35b7f
commit 7c734ee09b

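The hunks below drop the block_table argument from the sparse-attention helper and its six per-stage call sites (intra / succ / inter). By the time these calls are made, the per-stage key/value tensors (k_states_intra, k_states_succ, k_states_inter, ...) have already been gathered out of the paged KV cache into contiguous tensors, so forwarding a block_table into the underlying varlen flash-attention kernel would re-apply a paged lookup that has already been resolved. The sketch below illustrates that invariant in plain PyTorch; it is not vLLM's kernel path, and every name other than the torch APIs is hypothetical.

import torch
import torch.nn.functional as F

def gather_kv(cache: torch.Tensor, block_table: torch.Tensor,
              seq_len: int) -> torch.Tensor:
    """Resolve the paged layout once: (num_blocks, block_size, H, D) -> (seq_len, H, D)."""
    blocks = cache[block_table]                      # pick this sequence's physical blocks
    return blocks.reshape(-1, *cache.shape[2:])[:seq_len]

def stage_attention(q, k_cache, v_cache, block_table, k_len,
                    softmax_scale, causal):
    # K/V become dense here; nothing downstream should see block_table again.
    k = gather_kv(k_cache, block_table, k_len)
    v = gather_kv(v_cache, block_table, k_len)
    # (L, H, D) -> (1, H, L, D) for scaled_dot_product_attention
    q_, k_, v_ = (t.transpose(0, 1).unsqueeze(0) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q_, k_, v_,
                                         scale=softmax_scale,
                                         is_causal=causal)
    return out.squeeze(0).transpose(0, 1)            # back to (L, H, D)

In the real backend the dense K/V go to a varlen flash-attention kernel rather than scaled_dot_product_attention; the point is only that the paged indirection is consumed during gathering, so it must not be passed further down.
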
@@ -1055,7 +1055,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
 v_states_intra,
 softmax_scale=softmax_scale,
 causal=True,
-block_table=block_table,
 stage="intra",
 vertical_indices=vertical_buffer,
 slash_indices=slash_buffer,
@@ -1070,7 +1069,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
 v_states_intra,
 softmax_scale=softmax_scale,
 causal=True,
-block_table=block_table,
 stage="intra",
 vertical_indices=intra_vertical_indices,
 slash_indices=intra_slash_indices,
@@ -1085,7 +1083,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
 v_states_succ,
 softmax_scale=softmax_scale,
 causal=False,
-block_table=block_table,
 stage="succ",
 vertical_indices=succ_vertical_buffer,
 slash_indices=succ_slash_buffer,
@@ -1100,7 +1097,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
 v_states_succ,
 softmax_scale=softmax_scale,
 causal=False,
-block_table=block_table,
 stage="succ",
 vertical_indices=succ_vertical_indices,
 slash_indices=succ_slash_indices,
@@ -1115,7 +1111,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
 v_states_inter,
 softmax_scale=softmax_scale,
 causal=False,
-block_table=block_table,
 stage="inter",
 vertical_indices=inter_vertical_buffer,
 slash_indices=inter_slash_buffer,
@@ -1130,7 +1125,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
 v_states_inter,
 softmax_scale=softmax_scale,
 causal=False,
-block_table=block_table,
 stage="inter",
 vertical_indices=inter_vertical_indices,
 slash_indices=inter_slash_indices,
@@ -1151,7 +1145,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
 value_states: torch.Tensor,
 softmax_scale: float,
 causal: bool = True,
-block_table: torch.Tensor = None,
 max_seqlen_k: Optional[int] = None,
 stage: str = "intra",
 vertical_indices: Optional[torch.Tensor] = None,
@@ -1230,7 +1223,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
 device=query_states.device),
 max_seqlen_k=max_seqlen_k,
 causal=causal,
-block_table=block_table.unsqueeze(0),
 return_softmax_lse=True,
 )
 softmax_lse = softmax_lse.view(q_len, q_heads, 1).transpose(0,
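
The last two hunks are the core of the fix: the helper's signature no longer accepts block_table, and its internal varlen flash-attention call no longer forwards block_table.unsqueeze(0). The tiny self-contained check below shows the kind of failure this guards against; it is a reconstruction for illustration only, with made-up shapes and names, not what the kernel literally does.

import torch

num_blocks, block_size, heads, dim = 3, 2, 1, 4
k_cache = torch.arange(num_blocks * block_size * heads * dim,
                       dtype=torch.float32).view(num_blocks, block_size, heads, dim)
block_table = torch.tensor([2, 0, 1])            # logical -> physical block mapping

# Correct: resolve the paged layout exactly once.
k_dense = k_cache[block_table].reshape(-1, heads, dim)

# Bug shape: hand the already-dense K back to a paged lookup together with
# block_table -- the second lookup permutes blocks that were already in order.
k_twice = k_dense.view(num_blocks, block_size, heads, dim)[block_table] \
                 .reshape(-1, heads, dim)

print(torch.equal(k_dense, k_twice))             # False: tokens end up scrambled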