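"""Microbenchmark for the Pallas TPU paged_attention kernel.

Times single-token (decode) paged attention for several KV-cache sizes.
The final case (num_blocks = 1024) reproduces a TPU device halt; see the
BUG note at the bottom of this file.
"""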
import time

import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention

BLOCK_SIZE = 16


@jax.jit
def paged_attn(
    q: jax.Array,  # [batch, 1, num_heads, head_size]
    k_cache: jax.Array,  # [num_kv_heads, num_blocks * block_size, head_size]
    v_cache: jax.Array,  # [num_kv_heads, num_blocks * block_size, head_size]
    sm_scale: float,
    block_tables: jax.Array,  # [batch, max_num_blocks_per_batch]
    context_lens: jax.Array,  # [batch]
) -> jax.Array:  # [batch, 1, num_heads, head_size]
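    """Single-step (decode) paged attention over a flat, slot-major KV cache.

    Thin jitted wrapper around the Pallas TPU paged_attention kernel: it
    scales the query, reshapes the flat KV cache into pages, and restores
    the singleton sequence dimension on the output.
    """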
    q = q.squeeze(1)
    q = q * sm_scale

    head_size = q.shape[-1]
    num_slots = k_cache.shape[-2]
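    # The kernel expects the KV cache in paged form,
    # [num_kv_heads, num_pages, page_size, head_size], so fold the flat slot
    # dimension into (num_slots // BLOCK_SIZE) pages of BLOCK_SIZE slots each.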
    k_cache = k_cache.reshape(-1, num_slots // BLOCK_SIZE, BLOCK_SIZE, head_size)
    v_cache = v_cache.reshape(-1, num_slots // BLOCK_SIZE, BLOCK_SIZE, head_size)

    output = paged_attention(
        q,
        k_cache,
        v_cache,
        context_lens,
        block_tables,
        pages_per_compute_block=4,  # TODO(woosuk): Tune this value.
    )
    return output.reshape(q.shape[0], 1, q.shape[1], q.shape[2])


def benchmark_paged_attn(
    batch_size: int,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    context_len: int,
    num_blocks: int,
):
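    """Runs paged_attn 100 times for the given shapes and prints the average
    per-call latency in milliseconds."""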
    rng_key = jax.random.PRNGKey(0)
    query = jax.random.normal(rng_key, (batch_size, 1, num_heads, head_size), dtype=jnp.bfloat16)
    k_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * BLOCK_SIZE, head_size), dtype=jnp.bfloat16)
    v_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * BLOCK_SIZE, head_size), dtype=jnp.bfloat16)
    # Standard softmax scaling: 1 / sqrt(head_size).
    sm_scale = head_size ** -0.5
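    # Random page indices in [0, num_blocks); duplicates are possible but
    # harmless for a timing benchmark.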
    block_tables = jax.random.randint(rng_key, (batch_size, context_len // BLOCK_SIZE), 0, num_blocks, dtype=jnp.int32)
    context_lens = jnp.array([context_len] * batch_size, dtype=jnp.int32)

    # Warm-up call so JIT compilation is excluded from the timing below.
    output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens)
    output.block_until_ready()

    start = time.time()
    for _ in range(100):
        output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens)
    output.block_until_ready()
    end = time.time()

    # 100 iterations, so (end - start) * 10 is the average per-call latency in ms.
    print(f"Time taken: {(end - start) * 10:.2f} ms")


if __name__ == "__main__":
    for num_blocks in [16, 256, 512, 2048]:
        print(f"Benchmarking Paged Attention w/ {num_blocks} blocks")
        benchmark_paged_attn(1, 16, 16, 256, 128, num_blocks)

    # BUG: This will raise the following error:
    # jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Program or fatal error occurred; computation may be invalid:
    # INTERNAL: Accelerator device halted prematurely, perhaps due to an on-device check-failure.
    # Node 0 halted unexpectedly at tag:pc TensorCoreSequencer:1:0xad3 (from TensorCoreSequencer:1:0xad4):
    # no debugging message found for this tag:pc. HLO: custom-call.2; HLO computation: main.55
    num_blocks = 1024
    print(f"Benchmarking Paged Attention w/ {num_blocks} blocks")
    benchmark_paged_attn(1, 16, 16, 256, 128, num_blocks)