From c00ddd68343cbcf23811c1c121d0faa2e8f43bfc Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 30 Apr 2024 21:58:47 +0000 Subject: [PATCH] Add buffer donation to benchmark --- benchmarks/bench_cache_write.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/benchmarks/bench_cache_write.py b/benchmarks/bench_cache_write.py index 2c2ba318bfe8f..7229d758ba425 100644 --- a/benchmarks/bench_cache_write.py +++ b/benchmarks/bench_cache_write.py @@ -1,3 +1,4 @@ +import functools import time from typing import Tuple @@ -29,7 +30,7 @@ def write_to_kv_cache1( return k_cache, v_cache -@jax.jit +@functools.partial(jax.jit, donate_argnums=(2, 3)) def write_to_kv_cache2( key: jax.Array, # [batch_size, seq_len, num_heads, head_size] value: jax.Array, # [batch_size, seq_len, num_heads, head_size] @@ -60,6 +61,7 @@ def write_to_kv_cache2( return iterator.k_cache, iterator.v_cache +@functools.partial(jax.jit, donate_argnums=(2, 3)) def _write_seq_to_kv_cache( key: jax.Array, # [seq_len, num_heads, head_size] value: jax.Array, # [seq_len, num_heads, head_size] @@ -129,18 +131,18 @@ def benchmark_write_to_kv_cache( slot_mapping = jax.random.randint(rng_key, (batch_size, seq_len), 0, num_blocks * block_size, dtype=jnp.int32) # For JIT compilation. - k, v = f(key, value, k_cache, v_cache, slot_mapping) - k.block_until_ready() + k_cache, v_cache = f(key, value, k_cache, v_cache, slot_mapping) + k_cache.block_until_ready() start = time.time() for _ in range(100): - k, v = f(key, value, k_cache, v_cache, slot_mapping) - k.block_until_ready() + k_cache, v_cache = f(key, value, k_cache, v_cache, slot_mapping) + k_cache.block_until_ready() end = time.time() print(f"Time taken: {(end - start) * 10:.2f} ms") if __name__ == "__main__": - for num_blocks in [16, 256, 512, 1024, 2048]: + for num_blocks in [16, 256, 512, 1024, 2048, 8192, 16384]: print(f"Benchmarking Write to KV Cache w/ {num_blocks} blocks") - benchmark_write_to_kv_cache(1, 1024, 16, 256, num_blocks, 16, version=1) + benchmark_write_to_kv_cache(16, 256, 16, 256, num_blocks, 16, version=1)