mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-23 05:59:11 +08:00
Fix paged_attn
This commit is contained in:
parent
186c88c497
commit
7e3a230c38
@ -190,6 +190,7 @@ class Attention(nn.Module):
|
||||
query_proj,
|
||||
cache[0],
|
||||
cache[1],
|
||||
self.sm_scale,
|
||||
block_tables,
|
||||
context_lens,
|
||||
)
|
||||
|
||||
@ -6,10 +6,12 @@ def paged_attn(
|
||||
q: jax.Array, # [batch, 1, num_heads, head_size]
|
||||
k_cache: jax.Array, # [num_kv_heads, num_blocks, block_size, head_size]
|
||||
v_cache: jax.Array, # [num_kv_heads, num_blocks, block_size, head_size]
|
||||
sm_scale: float,
|
||||
block_tables: jax.Array, # [batch, max_num_blocks_per_batch]
|
||||
context_lens: jax.Array, # [batch]
|
||||
) -> jax.Array: # [batch, 1, num_heads, head_size]
|
||||
q = q.squeeze(1)
|
||||
q = q * sm_scale
|
||||
output = paged_attention(
|
||||
q,
|
||||
k_cache,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user