mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-19 11:36:59 +08:00
Tune pages_per_compute_block
This commit is contained in:
parent
278e8a1adc
commit
408ff4950c
@ -11,8 +11,8 @@ def paged_attn(
|
|||||||
context_lens: jax.Array, # [batch]
|
context_lens: jax.Array, # [batch]
|
||||||
block_size: int = 16,
|
block_size: int = 16,
|
||||||
) -> jax.Array: # [batch, 1, num_heads, head_size]
|
) -> jax.Array: # [batch, 1, num_heads, head_size]
|
||||||
q = q.squeeze(1)
|
|
||||||
q = q * sm_scale
|
q = q * sm_scale
|
||||||
|
q = q.squeeze(1)
|
||||||
|
|
||||||
head_size = q.shape[-1]
|
head_size = q.shape[-1]
|
||||||
num_slots = k_cache.shape[-2]
|
num_slots = k_cache.shape[-2]
|
||||||
@ -27,6 +27,6 @@ def paged_attn(
|
|||||||
v_cache,
|
v_cache,
|
||||||
context_lens,
|
context_lens,
|
||||||
block_tables,
|
block_tables,
|
||||||
pages_per_compute_block=4, # TODO(woosuk): Tune this value.
|
pages_per_compute_block=512 // 16, # TODO(woosuk): Tune this value.
|
||||||
)
|
)
|
||||||
return output.reshape(q.shape[0], 1, q.shape[1], q.shape[2])
|
return output.reshape(q.shape[0], 1, q.shape[1], q.shape[2])
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user