mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 22:05:44 +08:00
[FIX] Fix Alibi implementation in PagedAttention kernel (#945)
* [FIX] Fix Alibi implementation in PagedAttention kernel * Fix test_attention * Fix --------- Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: Oliver-ss <yuansongwx@outlook.com>
This commit is contained in:
parent
c957c741d9
commit
db09d4ad83
@ -178,7 +178,7 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
// This includes a reduction across the threads in the same thread group.
|
||||
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
|
||||
// Add the ALiBi bias if slopes are given.
|
||||
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len) : 0;
|
||||
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
|
||||
|
||||
if (thread_group_offset == 0) {
|
||||
// Store the partial reductions to shared memory.
|
||||
|
||||
@ -17,7 +17,7 @@ NUM_PREFILL_SEQS = [1, 3, 7] # Arbitrary values for testing
|
||||
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
|
||||
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
||||
BLOCK_SIZES = [8, 16, 32]
|
||||
USE_ALIBI = [False] # TODO(woosuk): Add USE_ALIBI=True
|
||||
USE_ALIBI = [False, True]
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
@ -83,7 +83,7 @@ def ref_single_query_cached_kv_attention(
|
||||
if alibi_slopes is not None:
|
||||
# Create the ALiBi bias used in the paged attention kernel.
|
||||
position_ids = torch.arange(context_len, device="cuda").int()
|
||||
alibi_bias = (context_len - position_ids).float()
|
||||
alibi_bias = (position_ids - context_len + 1).float()
|
||||
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
|
||||
1, 1, -1)
|
||||
|
||||
@ -224,6 +224,7 @@ def ref_multi_query_kv_attention(
|
||||
return ref_output
|
||||
|
||||
|
||||
# TODO(woosuk): Add tests for USE_ALIBI=True.
|
||||
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user