diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py
index bf4a05daf2d5a..543e8487e28b8 100644
--- a/vllm/v1/attention/backends/pallas.py
+++ b/vllm/v1/attention/backends/pallas.py
@@ -11,6 +11,7 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 
+# These are the 2 tunable parameters of the paged attention Pallas kernel.
 NUM_QUERIES_PER_BLOCK = 16
 NUM_KV_PAGES_PER_BLOCK = 128
 
@@ -154,6 +155,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
         write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
 
         query = query * self.scale
+        # use_kernel switches between using kernel or reference implementation
+        # (non kernel: https://github.com/pytorch/xla/blob/cee0820e78fc9675e2d0511db891fd44342e890d/torch_xla/experimental/custom_kernel.py#L890).
+        use_kernel = False
         output = torch.ops.xla.ragged_paged_attention(
             query,
             key_cache,
@@ -164,7 +168,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
             attn_metadata.num_seqs,
             num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
             num_queries_per_block=NUM_QUERIES_PER_BLOCK,
-            use_kernel=False,
+            use_kernel=use_kernel,
         )
 
         return output.reshape(num_tokens, hidden_size)
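
Not part of the diff above, just a minimal sketch of how the new `use_kernel` toggle could be driven from the environment instead of a hard-coded constant, e.g. to compare the Pallas kernel against the reference (non-kernel) path without editing the source. The variable name `VLLM_TPU_USE_PALLAS_KERNEL` is hypothetical, not an existing vLLM flag.

```python
import os

# Hypothetical toggle: keep the reference (non-kernel) implementation as the
# default (matching `use_kernel = False` in the diff) and switch to the Pallas
# kernel only when the environment explicitly enables it.
use_kernel = os.environ.get("VLLM_TPU_USE_PALLAS_KERNEL", "0") == "1"
```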