Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 23:34:27 +08:00)
[Bugfix][Qwen][DCA] Fixes a bug in the dual-chunk-flash-attn backend for Qwen 1M models (#21364)
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
This commit is contained in:
parent f59ec35b7f
commit 7c734ee09b
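For context, the hunks below all touch the sparse-attention helper used by the dual-chunk prefill path, which issues one attention call per stage ("intra", "succ", "inter"). The sketch that follows is illustrative only: the function name, tensor shapes, and the dense-attention body are assumptions, while the stage names and their causal flags are taken directly from the diff.

    # Illustrative sketch only: the function name, shapes, and dense-attention
    # body are assumptions; stage names and causal flags come from the diff.
    from typing import Optional

    import torch


    def sparse_attn_sketch(
        q: torch.Tensor,            # (q_len, heads, head_dim)
        k: torch.Tensor,            # (k_len, heads, head_dim)
        v: torch.Tensor,            # (k_len, heads, head_dim)
        softmax_scale: float,
        causal: bool,
        stage: str = "intra",
        vertical_indices: Optional[torch.Tensor] = None,
        slash_indices: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # In the real backend, vertical_indices / slash_indices appear to select
        # a sparse (vertical-and-slash) attention pattern; this sketch ignores
        # them and computes dense attention so the stage/causal semantics stay
        # visible.
        assert stage in ("intra", "succ", "inter")
        scores = torch.einsum("qhd,khd->hqk", q, k) * softmax_scale
        if causal:
            q_len, k_len = q.shape[0], k.shape[0]
            mask = torch.ones(q_len, k_len).tril(k_len - q_len).bool()
            scores = scores.masked_fill(~mask, float("-inf"))
        return torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), v)


    # The prefill path issues one such call per stage:
    #   stage="intra" with causal=True   -- queries attend within their own chunk,
    #   stage="succ"  with causal=False  -- queries attend to the preceding chunk,
    #   stage="inter" with causal=False  -- queries attend to earlier chunks.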
@@ -1055,7 +1055,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
             v_states_intra,
             softmax_scale=softmax_scale,
             causal=True,
-            block_table=block_table,
             stage="intra",
             vertical_indices=vertical_buffer,
             slash_indices=slash_buffer,
@@ -1070,7 +1069,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
             v_states_intra,
             softmax_scale=softmax_scale,
             causal=True,
-            block_table=block_table,
             stage="intra",
             vertical_indices=intra_vertical_indices,
             slash_indices=intra_slash_indices,
@@ -1085,7 +1083,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
             v_states_succ,
             softmax_scale=softmax_scale,
             causal=False,
-            block_table=block_table,
             stage="succ",
             vertical_indices=succ_vertical_buffer,
             slash_indices=succ_slash_buffer,
@@ -1100,7 +1097,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
             v_states_succ,
             softmax_scale=softmax_scale,
             causal=False,
-            block_table=block_table,
             stage="succ",
             vertical_indices=succ_vertical_indices,
             slash_indices=succ_slash_indices,
@@ -1115,7 +1111,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
             v_states_inter,
             softmax_scale=softmax_scale,
             causal=False,
-            block_table=block_table,
             stage="inter",
             vertical_indices=inter_vertical_buffer,
             slash_indices=inter_slash_buffer,
@@ -1130,7 +1125,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
             v_states_inter,
             softmax_scale=softmax_scale,
             causal=False,
-            block_table=block_table,
             stage="inter",
             vertical_indices=inter_vertical_indices,
             slash_indices=inter_slash_indices,
@@ -1151,7 +1145,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
        value_states: torch.Tensor,
        softmax_scale: float,
        causal: bool = True,
-        block_table: torch.Tensor = None,
        max_seqlen_k: Optional[int] = None,
        stage: str = "intra",
        vertical_indices: Optional[torch.Tensor] = None,
@@ -1230,7 +1223,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
                device=query_states.device),
            max_seqlen_k=max_seqlen_k,
            causal=causal,
-            block_table=block_table.unsqueeze(0),
            return_softmax_lse=True,
        )
        softmax_lse = softmax_lse.view(q_len, q_heads, 1).transpose(0,
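Taken together, the hunks drop the block_table keyword from the helper's signature and from each per-stage call site, including the inner flash-attn call that previously passed block_table.unsqueeze(0). Separately, that inner call requests return_softmax_lse=True and the last context line reshapes the returned softmax_lse, which is what allows the per-stage partial results to be recombined afterwards. Below is a self-contained sketch of that standard log-sum-exp merge for two partials over disjoint key sets; it is not vLLM's code, and the names and shapes are assumptions.

    # Standard recombination of two attention partials computed over disjoint
    # key sets, using the per-row log-sum-exp (LSE) returned alongside each
    # output. Not vLLM's actual code; names and shapes are illustrative.
    import torch


    def merge_attn_partials(
        out_a: torch.Tensor,   # (heads, q_len, head_dim): attention over key set A
        lse_a: torch.Tensor,   # (heads, q_len, 1):        logsumexp of scores on A
        out_b: torch.Tensor,   # (heads, q_len, head_dim): attention over key set B
        lse_b: torch.Tensor,   # (heads, q_len, 1):        logsumexp of scores on B
    ) -> torch.Tensor:
        # Softmax over A ∪ B is a convex combination of the two partial
        # softmaxes, weighted by exp(lse_a) and exp(lse_b) respectively.
        lse = torch.logaddexp(lse_a, lse_b)
        return out_a * (lse_a - lse).exp() + out_b * (lse_b - lse).exp()

The truncated context line above, softmax_lse.view(q_len, q_heads, 1).transpose(0, ..., is consistent with reshaping the LSE into that (heads, q_len, 1) layout before such a merge.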