mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 11:47:09 +08:00
Add buffer donation to benchmark
This commit is contained in:
parent
881b884046
commit
c00ddd6834
@@ -1,3 +1,4 @@
|
||||
import functools
|
||||
import time
|
||||
from typing import Tuple
|
||||
|
||||
@@ -29,7 +30,7 @@ def write_to_kv_cache1(
|
||||
return k_cache, v_cache
|
||||
|
||||
|
||||
@jax.jit
|
||||
@functools.partial(jax.jit, donate_argnums=(2, 3))
|
||||
def write_to_kv_cache2(
|
||||
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
||||
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
||||
@@ -60,6 +61,7 @@ def write_to_kv_cache2(
|
||||
return iterator.k_cache, iterator.v_cache
|
||||
|
||||
|
||||
@functools.partial(jax.jit, donate_argnums=(2, 3))
|
||||
def _write_seq_to_kv_cache(
|
||||
key: jax.Array, # [seq_len, num_heads, head_size]
|
||||
value: jax.Array, # [seq_len, num_heads, head_size]
|
||||
@@ -129,18 +131,18 @@ def benchmark_write_to_kv_cache(
|
||||
slot_mapping = jax.random.randint(rng_key, (batch_size, seq_len), 0, num_blocks * block_size, dtype=jnp.int32)
|
||||
|
||||
# For JIT compilation.
|
||||
k, v = f(key, value, k_cache, v_cache, slot_mapping)
|
||||
k.block_until_ready()
|
||||
k_cache, v_cache = f(key, value, k_cache, v_cache, slot_mapping)
|
||||
k_cache.block_until_ready()
|
||||
|
||||
start = time.time()
|
||||
for _ in range(100):
|
||||
k, v = f(key, value, k_cache, v_cache, slot_mapping)
|
||||
k.block_until_ready()
|
||||
k_cache, v_cache = f(key, value, k_cache, v_cache, slot_mapping)
|
||||
k_cache.block_until_ready()
|
||||
end = time.time()
|
||||
print(f"Time taken: {(end - start) * 10:.2f} ms")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
for num_blocks in [16, 256, 512, 1024, 2048]:
|
||||
for num_blocks in [16, 256, 512, 1024, 2048, 8192, 16384]:
|
||||
print(f"Benchmarking Write to KV Cache w/ {num_blocks} blocks")
|
||||
benchmark_write_to_kv_cache(1, 1024, 16, 256, num_blocks, 16, version=1)
|
||||
benchmark_write_to_kv_cache(16, 256, 16, 256, num_blocks, 16, version=1)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user