[v1] Add comments to the new ragged paged attention Pallas kernel (#14155)

Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
iefgnoix 2025-03-03 15:00:55 -08:00 committed by GitHub
parent cd1d3c3df8
commit 79e4937c65
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -11,6 +11,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
# These are the 2 tunable parameters of the paged attention Pallas kernel.
NUM_QUERIES_PER_BLOCK = 16
NUM_KV_PAGES_PER_BLOCK = 128
@ -154,6 +155,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
query = query * self.scale
# use_kernel switches between using kernel or reference implementation
# (non kernel: https://github.com/pytorch/xla/blob/cee0820e78fc9675e2d0511db891fd44342e890d/torch_xla/experimental/custom_kernel.py#L890).
use_kernel = False
output = torch.ops.xla.ragged_paged_attention(
query,
key_cache,
@ -164,7 +168,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
attn_metadata.num_seqs,
num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
num_queries_per_block=NUM_QUERIES_PER_BLOCK,
use_kernel=False,
use_kernel=use_kernel,
)
return output.reshape(num_tokens, hidden_size)