Tune pages_per_compute_block

This commit is contained in:
Woosuk Kwon 2024-04-26 08:55:23 +00:00
parent 278e8a1adc
commit 408ff4950c

View File

@@ -11,8 +11,8 @@ def paged_attn(
context_lens: jax.Array, # [batch]
block_size: int = 16,
) -> jax.Array: # [batch, 1, num_heads, head_size]
q = q.squeeze(1)
q = q * sm_scale
q = q.squeeze(1)
head_size = q.shape[-1]
num_slots = k_cache.shape[-2]
@@ -27,6 +27,6 @@ def paged_attn(
v_cache,
context_lens,
block_tables,
pages_per_compute_block=4, # TODO(woosuk): Tune this value.
pages_per_compute_block=512 // 16, # TODO(woosuk): Tune this value.
)
return output.reshape(q.shape[0], 1, q.shape[1], q.shape[2])