commit 22dd9c2730
parent a6d795d593

[Kernel] Optimize Prefill Attention in Unified Triton Attention Kernel (#20308)

Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
@@ -145,7 +145,19 @@ def kernel_unified_attention_2d(
                    mask=query_mask_1,
                    other=0.0)
 
-    num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
+    # compute the length of the longest sequence prefix spanned by any
+    # query token in the current q_block (q_block_local_idx)
+    max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + (
+        BLOCK_M - 1) // num_queries_per_kv + 1
+
+    # adjust for potential padding in the last q_block by considering the
+    # actual sequence length
+    max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)
+
+    # calculate the number of tiles (blocks) that need to be processed to
+    # cover the longest sequence prefix (due to causal masking, blocks beyond
+    # this prefix can be skipped)
+    num_blocks = cdiv_fn(max_seq_prefix_len, BLOCK_SIZE)
 
     # iterate through tiles
     for j in range(0, num_blocks):
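The change caps the inner K/V loop at the longest causal prefix that any query row in the current q_block can attend to, instead of always walking cdiv(seq_len, BLOCK_SIZE) tiles. Below is a minimal sketch of that arithmetic; the block sizes and sequence lengths are illustrative assumptions, not values from an actual kernel launch, and cdiv stands in for the kernel's cdiv_fn helper.

# Illustrative sketch (not part of the commit): how bounding the scan by the
# longest visible prefix reduces the number of K/V tiles for one q_block.
# All constants below are hypothetical example values.

def cdiv(a, b):
    # ceiling division, mirroring the kernel's cdiv_fn helper
    return (a + b - 1) // b

# hypothetical tiling configuration
BLOCK_SIZE = 16          # K/V tokens per tile
BLOCK_Q = 16             # query tokens per q_block
BLOCK_M = 64             # query rows per program
num_queries_per_kv = 4   # grouped-query attention ratio

# hypothetical prefill state for one q_block
seq_len = 4096           # full K/V sequence length
context_len = 1000       # tokens already cached before this prefill chunk
q_block_local_idx = 0    # first q_block of the request

# old behaviour: every q_block scans the whole sequence
num_blocks_old = cdiv(seq_len, BLOCK_SIZE)

# new behaviour: bound the scan by the longest causal prefix any query row
# in this q_block can see, then clamp to the actual sequence length
max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + (
    BLOCK_M - 1) // num_queries_per_kv + 1
max_seq_prefix_len = min(max_seq_prefix_len, seq_len)
num_blocks_new = cdiv(max_seq_prefix_len, BLOCK_SIZE)

print(num_blocks_old, num_blocks_new)  # 256 vs. 64 tiles for this q_block

During prefill, early q_blocks of a long request only attend to a short prefix, so skipping the tiles beyond max_seq_prefix_len is where the speedup of this commit comes from; later q_blocks converge back to scanning nearly the full sequence.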