Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-14 18:25:01 +08:00)
[Bugfix][Qwen][DCA] fixes bug in dual-chunk-flash-attn backend for qwen 1m models. (#21364)
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
commit 7c734ee09b
parent f59ec35b7f
@@ -1055,7 +1055,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
             v_states_intra,
             softmax_scale=softmax_scale,
             causal=True,
-            block_table=block_table,
             stage="intra",
             vertical_indices=vertical_buffer,
             slash_indices=slash_buffer,
@@ -1070,7 +1069,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
             v_states_intra,
             softmax_scale=softmax_scale,
             causal=True,
-            block_table=block_table,
             stage="intra",
             vertical_indices=intra_vertical_indices,
             slash_indices=intra_slash_indices,
@@ -1085,7 +1083,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
             v_states_succ,
             softmax_scale=softmax_scale,
             causal=False,
-            block_table=block_table,
             stage="succ",
             vertical_indices=succ_vertical_buffer,
             slash_indices=succ_slash_buffer,
@@ -1100,7 +1097,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
             v_states_succ,
             softmax_scale=softmax_scale,
             causal=False,
-            block_table=block_table,
             stage="succ",
             vertical_indices=succ_vertical_indices,
             slash_indices=succ_slash_indices,
@@ -1115,7 +1111,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
             v_states_inter,
             softmax_scale=softmax_scale,
             causal=False,
-            block_table=block_table,
             stage="inter",
             vertical_indices=inter_vertical_buffer,
             slash_indices=inter_slash_buffer,
@@ -1130,7 +1125,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
             v_states_inter,
             softmax_scale=softmax_scale,
             causal=False,
-            block_table=block_table,
             stage="inter",
             vertical_indices=inter_vertical_indices,
             slash_indices=inter_slash_indices,
@@ -1151,7 +1145,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
         value_states: torch.Tensor,
         softmax_scale: float,
         causal: bool = True,
-        block_table: torch.Tensor = None,
         max_seqlen_k: Optional[int] = None,
         stage: str = "intra",
         vertical_indices: Optional[torch.Tensor] = None,
@@ -1230,7 +1223,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
                 device=query_states.device),
             max_seqlen_k=max_seqlen_k,
             causal=causal,
-            block_table=block_table.unsqueeze(0),
             return_softmax_lse=True,
         )
         softmax_lse = softmax_lse.view(q_len, q_heads, 1).transpose(0,
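Every hunk above follows the same pattern: the block_table parameter is removed from the sparse-attention helper's signature, the block_table=block_table.unsqueeze(0) pass-through to the varlen flash-attention call is removed, and each intra/succ/inter call site stops passing block_table. A minimal sketch of the resulting signature change follows; the helper name _do_flash_attn, the leading query/key parameters, and any annotations not visible in the hunks are assumptions for illustration, and the bodies are elided.

from typing import Optional

import torch

# Before this commit: the helper accepted a paged-KV block table and
# forwarded it (as block_table.unsqueeze(0)) to the varlen attention call.
def _do_flash_attn_before(
    query_states: torch.Tensor,   # assumed leading parameters
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    softmax_scale: float,
    causal: bool = True,
    block_table: torch.Tensor = None,   # removed by this commit
    max_seqlen_k: Optional[int] = None,
    stage: str = "intra",
    vertical_indices: Optional[torch.Tensor] = None,
    slash_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    ...

# After this commit: block_table is gone from the signature and from every
# intra/succ/inter call site shown in the hunks above.
def _do_flash_attn_after(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    softmax_scale: float,
    causal: bool = True,
    max_seqlen_k: Optional[int] = None,
    stage: str = "intra",
    vertical_indices: Optional[torch.Tensor] = None,
    slash_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    ...

The hunks shown here only cover the removal itself; the reason the block table could be dropped at these call sites is not visible in this excerpt of the commit.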