diff --git a/benchmarks/bench_paged_attn.py b/benchmarks/bench_paged_attn.py index 42bb949d56cbc..b752252fc05f8 100644 --- a/benchmarks/bench_paged_attn.py +++ b/benchmarks/bench_paged_attn.py @@ -10,7 +10,7 @@ BLOCK_SIZE = 16 MAX_NUM_BLOCKS_PER_SEQ = 512 -@functools.partial(jax.jit, static_argnums=(6,)) +@functools.partial(jax.jit, static_argnums=(6, 7)) def paged_attn( q: jax.Array, # [batch, 1, num_heads, head_size] k_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size] @@ -18,6 +18,7 @@ def paged_attn( sm_scale: float, block_tables: jax.Array, # [batch, max_num_blocks_per_batch] context_lens: jax.Array, # [batch] + block_size: int, pages_per_compute_block: int, ) -> jax.Array: # [batch, 1, num_heads, head_size] q = q.squeeze(1) @@ -25,8 +26,8 @@ def paged_attn( head_size = q.shape[-1] num_slots = k_cache.shape[-2] - k_cache = k_cache.reshape(-1, num_slots // BLOCK_SIZE, BLOCK_SIZE, head_size) - v_cache = v_cache.reshape(-1, num_slots // BLOCK_SIZE, BLOCK_SIZE, head_size) + k_cache = k_cache.reshape(-1, num_slots // block_size, block_size, head_size) + v_cache = v_cache.reshape(-1, num_slots // block_size, block_size, head_size) output = paged_attention( q, @@ -46,23 +47,24 @@ def benchmark_paged_attn( head_size: int, context_len: int, num_blocks: int, + block_size: int, pages_per_compute_block: int, ): rng_key = jax.random.PRNGKey(0) query = jax.random.normal(rng_key, (batch_size, 1, num_heads, head_size), dtype=jnp.bfloat16) - k_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * BLOCK_SIZE, head_size), dtype=jnp.bfloat16) - v_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * BLOCK_SIZE, head_size), dtype=jnp.bfloat16) - sm_scale = BLOCK_SIZE ** -0.5 + k_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16) + v_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16) + sm_scale = head_size ** -0.5 block_tables = jax.random.randint(rng_key, (batch_size, MAX_NUM_BLOCKS_PER_SEQ), 0, num_blocks, dtype=jnp.int32) context_lens = jnp.array([context_len] * batch_size, dtype=jnp.int32) # For JIT compilation. - output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, pages_per_compute_block) + output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, block_size, pages_per_compute_block) output.block_until_ready() start = time.time() for _ in range(100): - output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, pages_per_compute_block) + output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, block_size, pages_per_compute_block) output.block_until_ready() end = time.time() @@ -76,20 +78,24 @@ if __name__ == "__main__": parser.add_argument("--num-kv-heads", type=int, default=16) parser.add_argument("--head-size", type=int, default=256) parser.add_argument("--context-len", type=int, default=512) + parser.add_argument("--num-blocks", type=int, default=2048) args = parser.parse_args() print(args) - for num_blocks in [2048]: - for pages_per_compute_block in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]: + for block_size in [16, 32, 64, 128]: + for pages_per_compute_block in [1, 2, 4, 8, 16, 32, 64, 128]: if pages_per_compute_block > MAX_NUM_BLOCKS_PER_SEQ: continue - print(f"num_blocks: {num_blocks}, pages_per_compute_block: {pages_per_compute_block}") + if block_size * pages_per_compute_block > 1024: + continue + print(f"block_size {block_size}, pages_per_compute_block: {pages_per_compute_block}") benchmark_paged_attn( args.batch_size, args.num_heads, args.num_kv_heads, args.head_size, args.context_len, - num_blocks, + args.num_blocks, + block_size, pages_per_compute_block, )