From 07be6ed3eb6c7f78603cbf63ba19cca8dd3e46c6 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 26 Apr 2024 08:54:41 +0000 Subject: [PATCH] Improve benchmark --- benchmarks/bench_paged_attn.py | 53 ++++++++++++++++++++++------------ 1 file changed, 35 insertions(+), 18 deletions(-) diff --git a/benchmarks/bench_paged_attn.py b/benchmarks/bench_paged_attn.py index 2ff27622fd6cb..42bb949d56cbc 100644 --- a/benchmarks/bench_paged_attn.py +++ b/benchmarks/bench_paged_attn.py @@ -1,11 +1,16 @@ +import argparse +import functools import time + import jax import jax.numpy as jnp from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention BLOCK_SIZE = 16 +MAX_NUM_BLOCKS_PER_SEQ = 512 -@jax.jit + +@functools.partial(jax.jit, static_argnums=(6,)) def paged_attn( q: jax.Array, # [batch, 1, num_heads, head_size] k_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size] @@ -13,6 +18,7 @@ def paged_attn( sm_scale: float, block_tables: jax.Array, # [batch, max_num_blocks_per_batch] context_lens: jax.Array, # [batch] + pages_per_compute_block: int, ) -> jax.Array: # [batch, 1, num_heads, head_size] q = q.squeeze(1) q = q * sm_scale @@ -28,7 +34,7 @@ def paged_attn( v_cache, context_lens, block_tables, - pages_per_compute_block=4, # TODO(woosuk): Tune this value. + pages_per_compute_block=pages_per_compute_block, ) return output.reshape(q.shape[0], 1, q.shape[1], q.shape[2]) @@ -40,39 +46,50 @@ def benchmark_paged_attn( head_size: int, context_len: int, num_blocks: int, + pages_per_compute_block: int, ): rng_key = jax.random.PRNGKey(0) query = jax.random.normal(rng_key, (batch_size, 1, num_heads, head_size), dtype=jnp.bfloat16) k_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * BLOCK_SIZE, head_size), dtype=jnp.bfloat16) v_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * BLOCK_SIZE, head_size), dtype=jnp.bfloat16) sm_scale = BLOCK_SIZE ** -0.5 - block_tables = jax.random.randint(rng_key, (batch_size, context_len // BLOCK_SIZE), 0, num_blocks, dtype=jnp.int32) + block_tables = jax.random.randint(rng_key, (batch_size, MAX_NUM_BLOCKS_PER_SEQ), 0, num_blocks, dtype=jnp.int32) context_lens = jnp.array([context_len] * batch_size, dtype=jnp.int32) # For JIT compilation. - output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens) + output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, pages_per_compute_block) output.block_until_ready() start = time.time() for _ in range(100): - output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens) + output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, pages_per_compute_block) output.block_until_ready() end = time.time() - print(f"Time taken: {(end - start) * 10:.2f} ms") + print(f"Time taken: {(end - start) * 10000:.2f} us") if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch-size", type=int, default=8) + parser.add_argument("--num-heads", type=int, default=16) + parser.add_argument("--num-kv-heads", type=int, default=16) + parser.add_argument("--head-size", type=int, default=256) + parser.add_argument("--context-len", type=int, default=512) + args = parser.parse_args() + print(args) - for num_blocks in [16, 256, 512, 2048]: - print(f"Benchmarking Paged Attention w/ {num_blocks} blocks") - benchmark_paged_attn(1, 16, 16, 256, 128, num_blocks) - - # BUG: This will raise the following error: - # jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Program or fatal error occurred; computation may be invalid: - # INTERNAL: Accelerator device halted prematurely, perhaps due to an on-device check-failure. - # Node 0 halted unexpectedly at tag:pc TensorCoreSequencer:1:0xad3 (from TensorCoreSequencer:1:0xad4): - # no debugging message found for this tag:pc. HLO: custom-call.2; HLO computation: main.55 - num_blocks = 1024 - print(f"Benchmarking Paged Attention w/ {num_blocks} blocks") - benchmark_paged_attn(1, 16, 16, 256, 128, num_blocks) + for num_blocks in [2048]: + for pages_per_compute_block in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]: + if pages_per_compute_block > MAX_NUM_BLOCKS_PER_SEQ: + continue + print(f"num_blocks: {num_blocks}, pages_per_compute_block: {pages_per_compute_block}") + benchmark_paged_attn( + args.batch_size, + args.num_heads, + args.num_kv_heads, + args.head_size, + args.context_len, + num_blocks, + pages_per_compute_block, + )