From 408ff4950c2bec008dad55e8eda6b40217cc5ce6 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 26 Apr 2024 08:55:23 +0000 Subject: [PATCH] Tune pages_per_compute_block --- vllm/model_executor/models/jax/ops/paged_attn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/jax/ops/paged_attn.py b/vllm/model_executor/models/jax/ops/paged_attn.py index f85445ecf3598..4d2480c4cae7a 100644 --- a/vllm/model_executor/models/jax/ops/paged_attn.py +++ b/vllm/model_executor/models/jax/ops/paged_attn.py @@ -11,8 +11,8 @@ def paged_attn( context_lens: jax.Array, # [batch] block_size: int = 16, ) -> jax.Array: # [batch, 1, num_heads, head_size] - q = q.squeeze(1) q = q * sm_scale + q = q.squeeze(1) head_size = q.shape[-1] num_slots = k_cache.shape[-2] @@ -27,6 +27,6 @@ def paged_attn( v_cache, context_lens, block_tables, - pages_per_compute_block=4, # TODO(woosuk): Tune this value. + pages_per_compute_block=512 // 16, # TODO(woosuk): Tune this value. ) return output.reshape(q.shape[0], 1, q.shape[1], q.shape[2])