Add block size

This commit is contained in:
Woosuk Kwon 2024-04-27 22:35:28 +00:00
parent 98a3df0f8d
commit 881b884046

View File

@ -10,7 +10,7 @@ BLOCK_SIZE = 16
MAX_NUM_BLOCKS_PER_SEQ = 512
@functools.partial(jax.jit, static_argnums=(6,))
@functools.partial(jax.jit, static_argnums=(6, 7))
def paged_attn(
q: jax.Array, # [batch, 1, num_heads, head_size]
k_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size]
@ -18,6 +18,7 @@ def paged_attn(
sm_scale: float,
block_tables: jax.Array, # [batch, max_num_blocks_per_batch]
context_lens: jax.Array, # [batch]
block_size: int,
pages_per_compute_block: int,
) -> jax.Array: # [batch, 1, num_heads, head_size]
q = q.squeeze(1)
@ -25,8 +26,8 @@ def paged_attn(
head_size = q.shape[-1]
num_slots = k_cache.shape[-2]
k_cache = k_cache.reshape(-1, num_slots // BLOCK_SIZE, BLOCK_SIZE, head_size)
v_cache = v_cache.reshape(-1, num_slots // BLOCK_SIZE, BLOCK_SIZE, head_size)
k_cache = k_cache.reshape(-1, num_slots // block_size, block_size, head_size)
v_cache = v_cache.reshape(-1, num_slots // block_size, block_size, head_size)
output = paged_attention(
q,
@ -46,23 +47,24 @@ def benchmark_paged_attn(
head_size: int,
context_len: int,
num_blocks: int,
block_size: int,
pages_per_compute_block: int,
):
rng_key = jax.random.PRNGKey(0)
query = jax.random.normal(rng_key, (batch_size, 1, num_heads, head_size), dtype=jnp.bfloat16)
k_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * BLOCK_SIZE, head_size), dtype=jnp.bfloat16)
v_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * BLOCK_SIZE, head_size), dtype=jnp.bfloat16)
sm_scale = BLOCK_SIZE ** -0.5
k_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
v_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
sm_scale = head_size ** -0.5
block_tables = jax.random.randint(rng_key, (batch_size, MAX_NUM_BLOCKS_PER_SEQ), 0, num_blocks, dtype=jnp.int32)
context_lens = jnp.array([context_len] * batch_size, dtype=jnp.int32)
# For JIT compilation.
output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, pages_per_compute_block)
output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, block_size, pages_per_compute_block)
output.block_until_ready()
start = time.time()
for _ in range(100):
output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, pages_per_compute_block)
output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, block_size, pages_per_compute_block)
output.block_until_ready()
end = time.time()
@ -76,20 +78,24 @@ if __name__ == "__main__":
parser.add_argument("--num-kv-heads", type=int, default=16)
parser.add_argument("--head-size", type=int, default=256)
parser.add_argument("--context-len", type=int, default=512)
parser.add_argument("--num-blocks", type=int, default=2048)
args = parser.parse_args()
print(args)
for num_blocks in [2048]:
for pages_per_compute_block in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]:
for block_size in [16, 32, 64, 128]:
for pages_per_compute_block in [1, 2, 4, 8, 16, 32, 64, 128]:
if pages_per_compute_block > MAX_NUM_BLOCKS_PER_SEQ:
continue
print(f"num_blocks: {num_blocks}, pages_per_compute_block: {pages_per_compute_block}")
if block_size * pages_per_compute_block > 1024:
continue
print(f"block_size {block_size}, pages_per_compute_block: {pages_per_compute_block}")
benchmark_paged_attn(
args.batch_size,
args.num_heads,
args.num_kv_heads,
args.head_size,
args.context_len,
num_blocks,
args.num_blocks,
block_size,
pages_per_compute_block,
)