This commit is contained in:
Woosuk Kwon 2024-04-26 05:30:08 +00:00
parent 5ae2f81c2b
commit d830766c0c

View File

@ -3,21 +3,23 @@ from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention
def paged_attn( def paged_attn(
q: jax.Array, # [batch, 1, num_heads, head_size] q: jax.Array, # [batch, 1, num_heads, head_size]
k_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size] k_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size]
v_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size] v_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size]
sm_scale: float, sm_scale: float,
block_tables: jax.Array, # [batch, max_num_blocks_per_batch] block_tables: jax.Array, # [batch, max_num_blocks_per_batch]
context_lens: jax.Array, # [batch] context_lens: jax.Array, # [batch]
block_size: int = 16, # FIXME(woosuk) block_size: int = 16,
) -> jax.Array: # [batch, 1, num_heads, head_size] ) -> jax.Array: # [batch, 1, num_heads, head_size]
q = q.squeeze(1) q = q.squeeze(1)
q = q * sm_scale q = q * sm_scale
head_size = q.shape[-1] head_size = q.shape[-1]
num_slots = k_cache.shape[-2] num_slots = k_cache.shape[-2]
k_cache = k_cache.reshape(-1, num_slots // block_size, block_size, head_size) k_cache = k_cache.reshape(-1, num_slots // block_size, block_size,
v_cache = v_cache.reshape(-1, num_slots // block_size, block_size, head_size) head_size)
v_cache = v_cache.reshape(-1, num_slots // block_size, block_size,
head_size)
output = paged_attention( output = paged_attention(
q, q,