Add buffer donation to benchmark

This commit is contained in:
Woosuk Kwon 2024-04-30 21:58:47 +00:00
parent 881b884046
commit c00ddd6834

View File

@ -1,3 +1,4 @@
import functools
import time
from typing import Tuple
@ -29,7 +30,7 @@ def write_to_kv_cache1(
return k_cache, v_cache
@jax.jit
@functools.partial(jax.jit, donate_argnums=(2, 3))
def write_to_kv_cache2(
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
@ -60,6 +61,7 @@ def write_to_kv_cache2(
return iterator.k_cache, iterator.v_cache
@functools.partial(jax.jit, donate_argnums=(2, 3))
def _write_seq_to_kv_cache(
key: jax.Array, # [seq_len, num_heads, head_size]
value: jax.Array, # [seq_len, num_heads, head_size]
@ -129,18 +131,18 @@ def benchmark_write_to_kv_cache(
slot_mapping = jax.random.randint(rng_key, (batch_size, seq_len), 0, num_blocks * block_size, dtype=jnp.int32)
# For JIT compilation.
k, v = f(key, value, k_cache, v_cache, slot_mapping)
k.block_until_ready()
k_cache, v_cache = f(key, value, k_cache, v_cache, slot_mapping)
k_cache.block_until_ready()
start = time.time()
for _ in range(100):
k, v = f(key, value, k_cache, v_cache, slot_mapping)
k.block_until_ready()
k_cache, v_cache = f(key, value, k_cache, v_cache, slot_mapping)
k_cache.block_until_ready()
end = time.time()
print(f"Time taken: {(end - start) * 10:.2f} ms")
if __name__ == "__main__":
for num_blocks in [16, 256, 512, 1024, 2048]:
for num_blocks in [16, 256, 512, 1024, 2048, 8192, 16384]:
print(f"Benchmarking Write to KV Cache w/ {num_blocks} blocks")
benchmark_write_to_kv_cache(1, 1024, 16, 256, num_blocks, 16, version=1)
benchmark_write_to_kv_cache(16, 256, 16, 256, num_blocks, 16, version=1)