Improve benchmark

This commit is contained in:
Woosuk Kwon 2024-04-26 08:54:41 +00:00
parent f6637dba18
commit 07be6ed3eb

View File

@ -1,11 +1,16 @@
import argparse
import functools
import time import time
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention
BLOCK_SIZE = 16 BLOCK_SIZE = 16
MAX_NUM_BLOCKS_PER_SEQ = 512
@jax.jit
@functools.partial(jax.jit, static_argnums=(6,))
def paged_attn( def paged_attn(
q: jax.Array, # [batch, 1, num_heads, head_size] q: jax.Array, # [batch, 1, num_heads, head_size]
k_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size] k_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size]
@ -13,6 +18,7 @@ def paged_attn(
sm_scale: float, sm_scale: float,
block_tables: jax.Array, # [batch, max_num_blocks_per_batch] block_tables: jax.Array, # [batch, max_num_blocks_per_batch]
context_lens: jax.Array, # [batch] context_lens: jax.Array, # [batch]
pages_per_compute_block: int,
) -> jax.Array: # [batch, 1, num_heads, head_size] ) -> jax.Array: # [batch, 1, num_heads, head_size]
q = q.squeeze(1) q = q.squeeze(1)
q = q * sm_scale q = q * sm_scale
@ -28,7 +34,7 @@ def paged_attn(
v_cache, v_cache,
context_lens, context_lens,
block_tables, block_tables,
pages_per_compute_block=4, # TODO(woosuk): Tune this value. pages_per_compute_block=pages_per_compute_block,
) )
return output.reshape(q.shape[0], 1, q.shape[1], q.shape[2]) return output.reshape(q.shape[0], 1, q.shape[1], q.shape[2])
@ -40,39 +46,50 @@ def benchmark_paged_attn(
head_size: int, head_size: int,
context_len: int, context_len: int,
num_blocks: int, num_blocks: int,
pages_per_compute_block: int,
): ):
rng_key = jax.random.PRNGKey(0) rng_key = jax.random.PRNGKey(0)
query = jax.random.normal(rng_key, (batch_size, 1, num_heads, head_size), dtype=jnp.bfloat16) query = jax.random.normal(rng_key, (batch_size, 1, num_heads, head_size), dtype=jnp.bfloat16)
k_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * BLOCK_SIZE, head_size), dtype=jnp.bfloat16) k_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * BLOCK_SIZE, head_size), dtype=jnp.bfloat16)
v_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * BLOCK_SIZE, head_size), dtype=jnp.bfloat16) v_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * BLOCK_SIZE, head_size), dtype=jnp.bfloat16)
sm_scale = BLOCK_SIZE ** -0.5 sm_scale = BLOCK_SIZE ** -0.5
block_tables = jax.random.randint(rng_key, (batch_size, context_len // BLOCK_SIZE), 0, num_blocks, dtype=jnp.int32) block_tables = jax.random.randint(rng_key, (batch_size, MAX_NUM_BLOCKS_PER_SEQ), 0, num_blocks, dtype=jnp.int32)
context_lens = jnp.array([context_len] * batch_size, dtype=jnp.int32) context_lens = jnp.array([context_len] * batch_size, dtype=jnp.int32)
# For JIT compilation. # For JIT compilation.
output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens) output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, pages_per_compute_block)
output.block_until_ready() output.block_until_ready()
start = time.time() start = time.time()
for _ in range(100): for _ in range(100):
output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens) output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, pages_per_compute_block)
output.block_until_ready() output.block_until_ready()
end = time.time() end = time.time()
print(f"Time taken: {(end - start) * 10:.2f} ms") print(f"Time taken: {(end - start) * 10000:.2f} us")
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--num-heads", type=int, default=16)
parser.add_argument("--num-kv-heads", type=int, default=16)
parser.add_argument("--head-size", type=int, default=256)
parser.add_argument("--context-len", type=int, default=512)
args = parser.parse_args()
print(args)
for num_blocks in [16, 256, 512, 2048]: for num_blocks in [2048]:
print(f"Benchmarking Paged Attention w/ {num_blocks} blocks") for pages_per_compute_block in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]:
benchmark_paged_attn(1, 16, 16, 256, 128, num_blocks) if pages_per_compute_block > MAX_NUM_BLOCKS_PER_SEQ:
continue
# BUG: This will raise the following error: print(f"num_blocks: {num_blocks}, pages_per_compute_block: {pages_per_compute_block}")
# jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Program or fatal error occurred; computation may be invalid: benchmark_paged_attn(
# INTERNAL: Accelerator device halted prematurely, perhaps due to an on-device check-failure. args.batch_size,
# Node 0 halted unexpectedly at tag:pc TensorCoreSequencer:1:0xad3 (from TensorCoreSequencer:1:0xad4): args.num_heads,
# no debugging message found for this tag:pc. HLO: custom-call.2; HLO computation: main.55 args.num_kv_heads,
num_blocks = 1024 args.head_size,
print(f"Benchmarking Paged Attention w/ {num_blocks} blocks") args.context_len,
benchmark_paged_attn(1, 16, 16, 256, 128, num_blocks) num_blocks,
pages_per_compute_block,
)