[v1] Add comments to the new ragged paged attention Pallas kernel (#14155)
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
parent cd1d3c3df8
commit 79e4937c65
@@ -11,6 +11,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 
+# These are the 2 tunable parameters of the paged attention Pallas kernel.
 NUM_QUERIES_PER_BLOCK = 16
 NUM_KV_PAGES_PER_BLOCK = 128
 
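Context for the first hunk: the new comment marks NUM_QUERIES_PER_BLOCK and NUM_KV_PAGES_PER_BLOCK as the two tunable block (tile) sizes of the ragged paged attention Pallas kernel. The sketch below is illustrative only, not vllm code; it shows how block sizes like these bound the number of tiles a kernel grid has to iterate over. The helper name and shapes are assumptions.

# Illustrative sketch (not vllm code): how the two block-size constants bound
# the tiling of the kernel grid. grid_tiles() is a hypothetical helper.
NUM_QUERIES_PER_BLOCK = 16
NUM_KV_PAGES_PER_BLOCK = 128

def grid_tiles(num_query_tokens: int, num_kv_pages: int) -> tuple[int, int]:
    """Ceil-divide the work into (query tiles, kv-page tiles)."""
    q_tiles = -(-num_query_tokens // NUM_QUERIES_PER_BLOCK)
    kv_tiles = -(-num_kv_pages // NUM_KV_PAGES_PER_BLOCK)
    return q_tiles, kv_tiles

print(grid_tiles(100, 300))  # -> (7, 3)

Larger blocks mean fewer grid steps but more on-chip memory per step, which is why the two constants are called out as the tunable knobs.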
@@ -154,6 +155,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
             write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
 
         query = query * self.scale
+        # use_kernel switches between using kernel or reference implementation
+        # (non kernel: https://github.com/pytorch/xla/blob/cee0820e78fc9675e2d0511db891fd44342e890d/torch_xla/experimental/custom_kernel.py#L890).
+        use_kernel = False
         output = torch.ops.xla.ragged_paged_attention(
             query,
             key_cache,
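The new use_kernel flag chooses between the Pallas kernel and the pure-PyTorch reference path linked in the comment (torch_xla's custom_kernel.py). The snippet below is a hedged sketch of what a non-kernel paged attention fallback roughly computes for a single decode query; the function name, shapes, and the no-GQA simplification are assumptions, not the torch_xla reference itself.

# Hedged sketch (not the torch_xla reference implementation): what a
# non-kernel paged attention fallback roughly does for one decode query.
# All names and shapes here are illustrative assumptions.
import torch

def paged_attention_reference(
    query: torch.Tensor,        # [num_heads, head_dim]
    key_cache: torch.Tensor,    # [num_pages, page_size, num_kv_heads, head_dim]
    value_cache: torch.Tensor,  # [num_pages, page_size, num_kv_heads, head_dim]
    block_table: torch.Tensor,  # page indices owned by this sequence
    seq_len: int,
    scale: float,
) -> torch.Tensor:
    # Gather this sequence's pages and flatten them into contiguous K/V.
    keys = key_cache[block_table].flatten(0, 1)[:seq_len]      # [seq_len, kv_heads, head_dim]
    values = value_cache[block_table].flatten(0, 1)[:seq_len]
    # Assume num_heads == num_kv_heads for simplicity (no GQA expansion).
    scores = torch.einsum("hd,shd->hs", query * scale, keys)
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("hs,shd->hd", probs, values)           # [num_heads, head_dim]

# Example shapes: 4 heads, head_dim 8, pages of size 16, seq_len 20.
q = torch.randn(4, 8)
kc = torch.randn(10, 16, 4, 8)
vc = torch.randn(10, 16, 4, 8)
out = paged_attention_reference(q, kc, vc, torch.tensor([3, 7]), seq_len=20, scale=0.125)
print(out.shape)  # torch.Size([4, 8])

The kernel path computes the same thing, but tiles the work on the TPU using the two block-size constants defined above.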
@@ -164,7 +168,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
             attn_metadata.num_seqs,
             num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
             num_queries_per_block=NUM_QUERIES_PER_BLOCK,
-            use_kernel=False,
+            use_kernel=use_kernel,
         )
 
         return output.reshape(num_tokens, hidden_size)
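Context for the last hunk, unchanged by this commit: the attention output is flattened back to the model's hidden size before being returned. The shape sketch below assumes a [num_tokens, num_heads, head_dim] output layout; it is not taken from the kernel's documentation.

# Hedged shape sketch (assumed layout): the final reshape folds the per-head
# outputs back into hidden_size = num_heads * head_dim.
import torch

num_tokens, num_heads, head_dim = 8, 4, 64
hidden_size = num_heads * head_dim  # 256

output = torch.randn(num_tokens, num_heads, head_dim)
flat = output.reshape(num_tokens, hidden_size)
assert flat.shape == (8, 256)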