mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-17 23:34:27 +08:00
[Attention] Optimize make_local_attention_virtual_batches for Flash Attention (#23185)
Signed-off-by: linzebing <linzebing1995@gmail.com>
This commit is contained in:
parent
64ab3c7253
commit
a634733f67
@ -464,8 +464,9 @@ def make_local_attention_virtual_batches(
|
|||||||
attn_chunk_size)[arange > 0]
|
attn_chunk_size)[arange > 0]
|
||||||
|
|
||||||
# convert from q_seqlens to cu_seqlens_q
|
# convert from q_seqlens to cu_seqlens_q
|
||||||
cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\
|
cu_seqlens_q_local = np.empty(virtual_batches + 1, dtype=np.int32)
|
||||||
.astype(np.int32)
|
np.cumsum(seqlens_q_local, out=cu_seqlens_q_local[1:])
|
||||||
|
cu_seqlens_q_local[0] = 0
|
||||||
|
|
||||||
# compute the seqlens_k_local,
|
# compute the seqlens_k_local,
|
||||||
# basically a full local attention block for all but the last block in each
|
# basically a full local attention block for all but the last block in each
|
||||||
@ -508,11 +509,10 @@ def make_local_attention_virtual_batches(
|
|||||||
# [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
|
# [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
|
||||||
# [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
|
# [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
|
||||||
# ]
|
# ]
|
||||||
block_indices= np.broadcast_to(
|
block_indices = (block_starts[:, None] +
|
||||||
np.arange(pages_per_local_batch, dtype=np.int32),
|
np.arange(pages_per_local_batch, dtype=np.int32))
|
||||||
(virtual_batches, pages_per_local_batch)) \
|
block_indices = block_indices.reshape(-1).clip(max=block_table.shape[1] -
|
||||||
+ np.expand_dims(block_starts, axis=1)
|
1)
|
||||||
block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
|
|
||||||
batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32),
|
batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32),
|
||||||
local_blocks * pages_per_local_batch)
|
local_blocks * pages_per_local_batch)
|
||||||
block_table_local = block_table[batch_indices, block_indices]\
|
block_table_local = block_table[batch_indices, block_indices]\
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user