Add buffer donation to benchmark

This commit is contained in:
Woosuk Kwon 2024-04-30 21:58:47 +00:00
parent 881b884046
commit c00ddd6834

View File

@ -1,3 +1,4 @@
import functools
import time import time
from typing import Tuple from typing import Tuple
@ -29,7 +30,7 @@ def write_to_kv_cache1(
return k_cache, v_cache return k_cache, v_cache
@jax.jit @functools.partial(jax.jit, donate_argnums=(2, 3))
def write_to_kv_cache2( def write_to_kv_cache2(
key: jax.Array, # [batch_size, seq_len, num_heads, head_size] key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
value: jax.Array, # [batch_size, seq_len, num_heads, head_size] value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
@ -60,6 +61,7 @@ def write_to_kv_cache2(
return iterator.k_cache, iterator.v_cache return iterator.k_cache, iterator.v_cache
@functools.partial(jax.jit, donate_argnums=(2, 3))
def _write_seq_to_kv_cache( def _write_seq_to_kv_cache(
key: jax.Array, # [seq_len, num_heads, head_size] key: jax.Array, # [seq_len, num_heads, head_size]
value: jax.Array, # [seq_len, num_heads, head_size] value: jax.Array, # [seq_len, num_heads, head_size]
@ -129,18 +131,18 @@ def benchmark_write_to_kv_cache(
slot_mapping = jax.random.randint(rng_key, (batch_size, seq_len), 0, num_blocks * block_size, dtype=jnp.int32) slot_mapping = jax.random.randint(rng_key, (batch_size, seq_len), 0, num_blocks * block_size, dtype=jnp.int32)
# For JIT compilation. # For JIT compilation.
k, v = f(key, value, k_cache, v_cache, slot_mapping) k_cache, v_cache = f(key, value, k_cache, v_cache, slot_mapping)
k.block_until_ready() k_cache.block_until_ready()
start = time.time() start = time.time()
for _ in range(100): for _ in range(100):
k, v = f(key, value, k_cache, v_cache, slot_mapping) k_cache, v_cache = f(key, value, k_cache, v_cache, slot_mapping)
k.block_until_ready() k_cache.block_until_ready()
end = time.time() end = time.time()
print(f"Time taken: {(end - start) * 10:.2f} ms") print(f"Time taken: {(end - start) * 10:.2f} ms")
if __name__ == "__main__": if __name__ == "__main__":
for num_blocks in [16, 256, 512, 1024, 2048]: for num_blocks in [16, 256, 512, 1024, 2048, 8192, 16384]:
print(f"Benchmarking Write to KV Cache w/ {num_blocks} blocks") print(f"Benchmarking Write to KV Cache w/ {num_blocks} blocks")
benchmark_write_to_kv_cache(1, 1024, 16, 256, num_blocks, 16, version=1) benchmark_write_to_kv_cache(16, 256, 16, 256, num_blocks, 16, version=1)