Fix paged_attn

This commit is contained in:
Woosuk Kwon 2024-04-17 20:06:26 +00:00
parent 186c88c497
commit 7e3a230c38
2 changed files with 3 additions and 0 deletions

View File

@ -190,6 +190,7 @@ class Attention(nn.Module):
query_proj,
cache[0],
cache[1],
self.sm_scale,
block_tables,
context_lens,
)

View File

@ -6,10 +6,12 @@ def paged_attn(
q: jax.Array, # [batch, 1, num_heads, head_size]
k_cache: jax.Array, # [num_kv_heads, num_blocks, block_size, head_size]
v_cache: jax.Array, # [num_kv_heads, num_blocks, block_size, head_size]
sm_scale: float,
block_tables: jax.Array, # [batch, max_num_blocks_per_batch]
context_lens: jax.Array, # [batch]
) -> jax.Array: # [batch, 1, num_heads, head_size]
q = q.squeeze(1)
q = q * sm_scale
output = paged_attention(
q,
k_cache,