mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-31 06:37:04 +08:00
Fix paged_attn
This commit is contained in:
parent
186c88c497
commit
7e3a230c38
@ -190,6 +190,7 @@ class Attention(nn.Module):
|
|||||||
query_proj,
|
query_proj,
|
||||||
cache[0],
|
cache[0],
|
||||||
cache[1],
|
cache[1],
|
||||||
|
self.sm_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
context_lens,
|
context_lens,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -6,10 +6,12 @@ def paged_attn(
|
|||||||
q: jax.Array, # [batch, 1, num_heads, head_size]
|
q: jax.Array, # [batch, 1, num_heads, head_size]
|
||||||
k_cache: jax.Array, # [num_kv_heads, num_blocks, block_size, head_size]
|
k_cache: jax.Array, # [num_kv_heads, num_blocks, block_size, head_size]
|
||||||
v_cache: jax.Array, # [num_kv_heads, num_blocks, block_size, head_size]
|
v_cache: jax.Array, # [num_kv_heads, num_blocks, block_size, head_size]
|
||||||
|
sm_scale: float,
|
||||||
block_tables: jax.Array, # [batch, max_num_blocks_per_batch]
|
block_tables: jax.Array, # [batch, max_num_blocks_per_batch]
|
||||||
context_lens: jax.Array, # [batch]
|
context_lens: jax.Array, # [batch]
|
||||||
) -> jax.Array: # [batch, 1, num_heads, head_size]
|
) -> jax.Array: # [batch, 1, num_heads, head_size]
|
||||||
q = q.squeeze(1)
|
q = q.squeeze(1)
|
||||||
|
q = q * sm_scale
|
||||||
output = paged_attention(
|
output = paged_attention(
|
||||||
q,
|
q,
|
||||||
k_cache,
|
k_cache,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user