mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-19 07:27:00 +08:00
yapf
This commit is contained in:
parent
5ae2f81c2b
commit
d830766c0c
@ -3,21 +3,23 @@ from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention
|
|||||||
|
|
||||||
|
|
||||||
def paged_attn(
|
def paged_attn(
|
||||||
q: jax.Array, # [batch, 1, num_heads, head_size]
|
q: jax.Array, # [batch, 1, num_heads, head_size]
|
||||||
k_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size]
|
k_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size]
|
||||||
v_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size]
|
v_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size]
|
||||||
sm_scale: float,
|
sm_scale: float,
|
||||||
block_tables: jax.Array, # [batch, max_num_blocks_per_batch]
|
block_tables: jax.Array, # [batch, max_num_blocks_per_batch]
|
||||||
context_lens: jax.Array, # [batch]
|
context_lens: jax.Array, # [batch]
|
||||||
block_size: int = 16, # FIXME(woosuk)
|
block_size: int = 16,
|
||||||
) -> jax.Array: # [batch, 1, num_heads, head_size]
|
) -> jax.Array: # [batch, 1, num_heads, head_size]
|
||||||
q = q.squeeze(1)
|
q = q.squeeze(1)
|
||||||
q = q * sm_scale
|
q = q * sm_scale
|
||||||
|
|
||||||
head_size = q.shape[-1]
|
head_size = q.shape[-1]
|
||||||
num_slots = k_cache.shape[-2]
|
num_slots = k_cache.shape[-2]
|
||||||
k_cache = k_cache.reshape(-1, num_slots // block_size, block_size, head_size)
|
k_cache = k_cache.reshape(-1, num_slots // block_size, block_size,
|
||||||
v_cache = v_cache.reshape(-1, num_slots // block_size, block_size, head_size)
|
head_size)
|
||||||
|
v_cache = v_cache.reshape(-1, num_slots // block_size, block_size,
|
||||||
|
head_size)
|
||||||
|
|
||||||
output = paged_attention(
|
output = paged_attention(
|
||||||
q,
|
q,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user