From 79e4937c65d5f553f878293a0da50f83b3773141 Mon Sep 17 00:00:00 2001
From: iefgnoix
Date: Mon, 3 Mar 2025 15:00:55 -0800
Subject: [PATCH] [v1] Add comments to the new ragged paged attention Pallas
 kernel (#14155)

Signed-off-by: Xiongfei Wei
Co-authored-by: Michael Goin
---
 vllm/v1/attention/backends/pallas.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py
index bf4a05daf2d5a..543e8487e28b8 100644
--- a/vllm/v1/attention/backends/pallas.py
+++ b/vllm/v1/attention/backends/pallas.py
@@ -11,6 +11,7 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 
+# These are the 2 tunable parameters of the paged attention Pallas kernel.
 NUM_QUERIES_PER_BLOCK = 16
 NUM_KV_PAGES_PER_BLOCK = 128
@@ -154,6 +155,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
         write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
 
         query = query * self.scale
+        # use_kernel switches between using kernel or reference implementation
+        # (non kernel: https://github.com/pytorch/xla/blob/cee0820e78fc9675e2d0511db891fd44342e890d/torch_xla/experimental/custom_kernel.py#L890).
+        use_kernel = False
         output = torch.ops.xla.ragged_paged_attention(
             query,
             key_cache,
@@ -164,7 +168,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
             attn_metadata.num_seqs,
             num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
             num_queries_per_block=NUM_QUERIES_PER_BLOCK,
-            use_kernel=False,
+            use_kernel=use_kernel,
         )
 
         return output.reshape(num_tokens, hidden_size)
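
For context, the sketch below shows what a reference (use_kernel=False) ragged
paged attention computes: causal attention over ragged query sequences whose
keys and values live in paged KV caches indexed through per-sequence block
tables. This is a minimal pure-PyTorch illustration, not the torch_xla
implementation linked in the patch; the function name, argument names, and
shapes are assumptions made for this example. It assumes num_heads equals
num_kv_heads and that context_lens counts all KV tokens for a sequence,
including its current query tokens.

import torch

def ragged_paged_attention_ref(queries, key_cache, value_cache, context_lens,
                               block_tables, query_lens, sm_scale=1.0):
    # queries:     [total_query_tokens, num_heads, head_dim], ragged over seqs
    # key_cache /
    # value_cache: [num_pages, page_size, num_kv_heads, head_dim]
    # context_lens: [num_seqs] valid KV tokens per sequence (incl. current)
    # block_tables: [num_seqs, max_pages] page indices per sequence
    # query_lens:  [num_seqs] query tokens per sequence
    num_pages, page_size, num_kv_heads, head_dim = key_cache.shape
    outputs = []
    q_start = 0
    for i in range(query_lens.shape[0]):
        q_len = int(query_lens[i])
        kv_len = int(context_lens[i])
        q = queries[q_start:q_start + q_len] * sm_scale  # [q_len, heads, dim]
        q_start += q_len
        # Gather this sequence's KV pages and flatten to token granularity.
        n_pages = (kv_len + page_size - 1) // page_size
        pages = block_tables[i, :n_pages]
        k = key_cache[pages].reshape(-1, num_kv_heads, head_dim)[:kv_len]
        v = value_cache[pages].reshape(-1, num_kv_heads, head_dim)[:kv_len]
        # Attention scores per head: [heads, q_len, kv_len].
        scores = torch.einsum("qhd,khd->hqk", q, k)
        # Causal mask: query token j sits at absolute KV position
        # kv_len - q_len + j, and may only attend to positions <= that.
        q_pos = torch.arange(q_len).unsqueeze(1) + (kv_len - q_len)
        k_pos = torch.arange(kv_len).unsqueeze(0)
        scores = scores.masked_fill(k_pos > q_pos, float("-inf"))
        probs = torch.softmax(scores, dim=-1)
        outputs.append(torch.einsum("hqk,khd->qhd", probs, v))
    return torch.cat(outputs, dim=0)  # [total_query_tokens, heads, dim]

Note that in the patched code the query is pre-scaled by self.scale before the
op is called, so the equivalent of sm_scale there is 1.0; the kernel's two
tunable block sizes (NUM_QUERIES_PER_BLOCK, NUM_KV_PAGES_PER_BLOCK) have no
counterpart in this per-sequence reference loop.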