mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-04 17:42:14 +08:00
Improve benchmark
This commit is contained in:
parent
f6637dba18
commit
07be6ed3eb
@ -1,11 +1,16 @@
|
|||||||
|
import argparse
|
||||||
|
import functools
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention
|
from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention
|
||||||
|
|
||||||
BLOCK_SIZE = 16
|
BLOCK_SIZE = 16
|
||||||
|
MAX_NUM_BLOCKS_PER_SEQ = 512
|
||||||
|
|
||||||
@jax.jit
|
|
||||||
|
@functools.partial(jax.jit, static_argnums=(6,))
|
||||||
def paged_attn(
|
def paged_attn(
|
||||||
q: jax.Array, # [batch, 1, num_heads, head_size]
|
q: jax.Array, # [batch, 1, num_heads, head_size]
|
||||||
k_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size]
|
k_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size]
|
||||||
@ -13,6 +18,7 @@ def paged_attn(
|
|||||||
sm_scale: float,
|
sm_scale: float,
|
||||||
block_tables: jax.Array, # [batch, max_num_blocks_per_batch]
|
block_tables: jax.Array, # [batch, max_num_blocks_per_batch]
|
||||||
context_lens: jax.Array, # [batch]
|
context_lens: jax.Array, # [batch]
|
||||||
|
pages_per_compute_block: int,
|
||||||
) -> jax.Array: # [batch, 1, num_heads, head_size]
|
) -> jax.Array: # [batch, 1, num_heads, head_size]
|
||||||
q = q.squeeze(1)
|
q = q.squeeze(1)
|
||||||
q = q * sm_scale
|
q = q * sm_scale
|
||||||
@ -28,7 +34,7 @@ def paged_attn(
|
|||||||
v_cache,
|
v_cache,
|
||||||
context_lens,
|
context_lens,
|
||||||
block_tables,
|
block_tables,
|
||||||
pages_per_compute_block=4, # TODO(woosuk): Tune this value.
|
pages_per_compute_block=pages_per_compute_block,
|
||||||
)
|
)
|
||||||
return output.reshape(q.shape[0], 1, q.shape[1], q.shape[2])
|
return output.reshape(q.shape[0], 1, q.shape[1], q.shape[2])
|
||||||
|
|
||||||
@ -40,39 +46,50 @@ def benchmark_paged_attn(
|
|||||||
head_size: int,
|
head_size: int,
|
||||||
context_len: int,
|
context_len: int,
|
||||||
num_blocks: int,
|
num_blocks: int,
|
||||||
|
pages_per_compute_block: int,
|
||||||
):
|
):
|
||||||
rng_key = jax.random.PRNGKey(0)
|
rng_key = jax.random.PRNGKey(0)
|
||||||
query = jax.random.normal(rng_key, (batch_size, 1, num_heads, head_size), dtype=jnp.bfloat16)
|
query = jax.random.normal(rng_key, (batch_size, 1, num_heads, head_size), dtype=jnp.bfloat16)
|
||||||
k_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * BLOCK_SIZE, head_size), dtype=jnp.bfloat16)
|
k_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * BLOCK_SIZE, head_size), dtype=jnp.bfloat16)
|
||||||
v_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * BLOCK_SIZE, head_size), dtype=jnp.bfloat16)
|
v_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * BLOCK_SIZE, head_size), dtype=jnp.bfloat16)
|
||||||
sm_scale = BLOCK_SIZE ** -0.5
|
sm_scale = BLOCK_SIZE ** -0.5
|
||||||
block_tables = jax.random.randint(rng_key, (batch_size, context_len // BLOCK_SIZE), 0, num_blocks, dtype=jnp.int32)
|
block_tables = jax.random.randint(rng_key, (batch_size, MAX_NUM_BLOCKS_PER_SEQ), 0, num_blocks, dtype=jnp.int32)
|
||||||
context_lens = jnp.array([context_len] * batch_size, dtype=jnp.int32)
|
context_lens = jnp.array([context_len] * batch_size, dtype=jnp.int32)
|
||||||
|
|
||||||
# For JIT compilation.
|
# For JIT compilation.
|
||||||
output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens)
|
output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, pages_per_compute_block)
|
||||||
output.block_until_ready()
|
output.block_until_ready()
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
for _ in range(100):
|
for _ in range(100):
|
||||||
output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens)
|
output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, pages_per_compute_block)
|
||||||
output.block_until_ready()
|
output.block_until_ready()
|
||||||
end = time.time()
|
end = time.time()
|
||||||
|
|
||||||
print(f"Time taken: {(end - start) * 10:.2f} ms")
|
print(f"Time taken: {(end - start) * 10000:.2f} us")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--batch-size", type=int, default=8)
|
||||||
|
parser.add_argument("--num-heads", type=int, default=16)
|
||||||
|
parser.add_argument("--num-kv-heads", type=int, default=16)
|
||||||
|
parser.add_argument("--head-size", type=int, default=256)
|
||||||
|
parser.add_argument("--context-len", type=int, default=512)
|
||||||
|
args = parser.parse_args()
|
||||||
|
print(args)
|
||||||
|
|
||||||
for num_blocks in [16, 256, 512, 2048]:
|
for num_blocks in [2048]:
|
||||||
print(f"Benchmarking Paged Attention w/ {num_blocks} blocks")
|
for pages_per_compute_block in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]:
|
||||||
benchmark_paged_attn(1, 16, 16, 256, 128, num_blocks)
|
if pages_per_compute_block > MAX_NUM_BLOCKS_PER_SEQ:
|
||||||
|
continue
|
||||||
# BUG: This will raise the following error:
|
print(f"num_blocks: {num_blocks}, pages_per_compute_block: {pages_per_compute_block}")
|
||||||
# jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Program or fatal error occurred; computation may be invalid:
|
benchmark_paged_attn(
|
||||||
# INTERNAL: Accelerator device halted prematurely, perhaps due to an on-device check-failure.
|
args.batch_size,
|
||||||
# Node 0 halted unexpectedly at tag:pc TensorCoreSequencer:1:0xad3 (from TensorCoreSequencer:1:0xad4):
|
args.num_heads,
|
||||||
# no debugging message found for this tag:pc. HLO: custom-call.2; HLO computation: main.55
|
args.num_kv_heads,
|
||||||
num_blocks = 1024
|
args.head_size,
|
||||||
print(f"Benchmarking Paged Attention w/ {num_blocks} blocks")
|
args.context_len,
|
||||||
benchmark_paged_attn(1, 16, 16, 256, 128, num_blocks)
|
num_blocks,
|
||||||
|
pages_per_compute_block,
|
||||||
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user