Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-03-16 14:27:19 +08:00)
Add buffer donation to benchmark
This commit is contained in:
parent: 881b884046
commit: c00ddd6834
@@ -1,3 +1,4 @@
|
|||||||
|
import functools
|
||||||
import time
|
import time
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
@@ -29,7 +30,7 @@ def write_to_kv_cache1(
|
|||||||
return k_cache, v_cache
|
return k_cache, v_cache
|
||||||
|
|
||||||
|
|
||||||
@jax.jit
|
@functools.partial(jax.jit, donate_argnums=(2, 3))
|
||||||
def write_to_kv_cache2(
|
def write_to_kv_cache2(
|
||||||
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
||||||
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
||||||
@@ -60,6 +61,7 @@ def write_to_kv_cache2(
|
|||||||
return iterator.k_cache, iterator.v_cache
|
return iterator.k_cache, iterator.v_cache
|
||||||
|
|
||||||
|
|
||||||
|
@functools.partial(jax.jit, donate_argnums=(2, 3))
|
||||||
def _write_seq_to_kv_cache(
|
def _write_seq_to_kv_cache(
|
||||||
key: jax.Array, # [seq_len, num_heads, head_size]
|
key: jax.Array, # [seq_len, num_heads, head_size]
|
||||||
value: jax.Array, # [seq_len, num_heads, head_size]
|
value: jax.Array, # [seq_len, num_heads, head_size]
|
||||||
@@ -129,18 +131,18 @@ def benchmark_write_to_kv_cache(
|
|||||||
slot_mapping = jax.random.randint(rng_key, (batch_size, seq_len), 0, num_blocks * block_size, dtype=jnp.int32)
|
slot_mapping = jax.random.randint(rng_key, (batch_size, seq_len), 0, num_blocks * block_size, dtype=jnp.int32)
|
||||||
|
|
||||||
# For JIT compilation.
|
# For JIT compilation.
|
||||||
k, v = f(key, value, k_cache, v_cache, slot_mapping)
|
k_cache, v_cache = f(key, value, k_cache, v_cache, slot_mapping)
|
||||||
k.block_until_ready()
|
k_cache.block_until_ready()
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
for _ in range(100):
|
for _ in range(100):
|
||||||
k, v = f(key, value, k_cache, v_cache, slot_mapping)
|
k_cache, v_cache = f(key, value, k_cache, v_cache, slot_mapping)
|
||||||
k.block_until_ready()
|
k_cache.block_until_ready()
|
||||||
end = time.time()
|
end = time.time()
|
||||||
print(f"Time taken: {(end - start) * 10:.2f} ms")
|
print(f"Time taken: {(end - start) * 10:.2f} ms")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
for num_blocks in [16, 256, 512, 1024, 2048]:
|
for num_blocks in [16, 256, 512, 1024, 2048, 8192, 16384]:
|
||||||
print(f"Benchmarking Write to KV Cache w/ {num_blocks} blocks")
|
print(f"Benchmarking Write to KV Cache w/ {num_blocks} blocks")
|
||||||
benchmark_write_to_kv_cache(1, 1024, 16, 256, num_blocks, 16, version=1)
|
benchmark_write_to_kv_cache(16, 256, 16, 256, num_blocks, 16, version=1)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user