mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-20 22:17:08 +08:00
Add block size
This commit is contained in:
parent
98a3df0f8d
commit
881b884046
@ -10,7 +10,7 @@ BLOCK_SIZE = 16
|
|||||||
MAX_NUM_BLOCKS_PER_SEQ = 512
|
MAX_NUM_BLOCKS_PER_SEQ = 512
|
||||||
|
|
||||||
|
|
||||||
@functools.partial(jax.jit, static_argnums=(6,))
|
@functools.partial(jax.jit, static_argnums=(6, 7))
|
||||||
def paged_attn(
|
def paged_attn(
|
||||||
q: jax.Array, # [batch, 1, num_heads, head_size]
|
q: jax.Array, # [batch, 1, num_heads, head_size]
|
||||||
k_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size]
|
k_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size]
|
||||||
@ -18,6 +18,7 @@ def paged_attn(
|
|||||||
sm_scale: float,
|
sm_scale: float,
|
||||||
block_tables: jax.Array, # [batch, max_num_blocks_per_batch]
|
block_tables: jax.Array, # [batch, max_num_blocks_per_batch]
|
||||||
context_lens: jax.Array, # [batch]
|
context_lens: jax.Array, # [batch]
|
||||||
|
block_size: int,
|
||||||
pages_per_compute_block: int,
|
pages_per_compute_block: int,
|
||||||
) -> jax.Array: # [batch, 1, num_heads, head_size]
|
) -> jax.Array: # [batch, 1, num_heads, head_size]
|
||||||
q = q.squeeze(1)
|
q = q.squeeze(1)
|
||||||
@ -25,8 +26,8 @@ def paged_attn(
|
|||||||
|
|
||||||
head_size = q.shape[-1]
|
head_size = q.shape[-1]
|
||||||
num_slots = k_cache.shape[-2]
|
num_slots = k_cache.shape[-2]
|
||||||
k_cache = k_cache.reshape(-1, num_slots // BLOCK_SIZE, BLOCK_SIZE, head_size)
|
k_cache = k_cache.reshape(-1, num_slots // block_size, block_size, head_size)
|
||||||
v_cache = v_cache.reshape(-1, num_slots // BLOCK_SIZE, BLOCK_SIZE, head_size)
|
v_cache = v_cache.reshape(-1, num_slots // block_size, block_size, head_size)
|
||||||
|
|
||||||
output = paged_attention(
|
output = paged_attention(
|
||||||
q,
|
q,
|
||||||
@ -46,23 +47,24 @@ def benchmark_paged_attn(
|
|||||||
head_size: int,
|
head_size: int,
|
||||||
context_len: int,
|
context_len: int,
|
||||||
num_blocks: int,
|
num_blocks: int,
|
||||||
|
block_size: int,
|
||||||
pages_per_compute_block: int,
|
pages_per_compute_block: int,
|
||||||
):
|
):
|
||||||
rng_key = jax.random.PRNGKey(0)
|
rng_key = jax.random.PRNGKey(0)
|
||||||
query = jax.random.normal(rng_key, (batch_size, 1, num_heads, head_size), dtype=jnp.bfloat16)
|
query = jax.random.normal(rng_key, (batch_size, 1, num_heads, head_size), dtype=jnp.bfloat16)
|
||||||
k_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * BLOCK_SIZE, head_size), dtype=jnp.bfloat16)
|
k_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
|
||||||
v_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * BLOCK_SIZE, head_size), dtype=jnp.bfloat16)
|
v_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
|
||||||
sm_scale = BLOCK_SIZE ** -0.5
|
sm_scale = head_size ** -0.5
|
||||||
block_tables = jax.random.randint(rng_key, (batch_size, MAX_NUM_BLOCKS_PER_SEQ), 0, num_blocks, dtype=jnp.int32)
|
block_tables = jax.random.randint(rng_key, (batch_size, MAX_NUM_BLOCKS_PER_SEQ), 0, num_blocks, dtype=jnp.int32)
|
||||||
context_lens = jnp.array([context_len] * batch_size, dtype=jnp.int32)
|
context_lens = jnp.array([context_len] * batch_size, dtype=jnp.int32)
|
||||||
|
|
||||||
# For JIT compilation.
|
# For JIT compilation.
|
||||||
output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, pages_per_compute_block)
|
output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, block_size, pages_per_compute_block)
|
||||||
output.block_until_ready()
|
output.block_until_ready()
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
for _ in range(100):
|
for _ in range(100):
|
||||||
output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, pages_per_compute_block)
|
output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, block_size, pages_per_compute_block)
|
||||||
output.block_until_ready()
|
output.block_until_ready()
|
||||||
end = time.time()
|
end = time.time()
|
||||||
|
|
||||||
@ -76,20 +78,24 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--num-kv-heads", type=int, default=16)
|
parser.add_argument("--num-kv-heads", type=int, default=16)
|
||||||
parser.add_argument("--head-size", type=int, default=256)
|
parser.add_argument("--head-size", type=int, default=256)
|
||||||
parser.add_argument("--context-len", type=int, default=512)
|
parser.add_argument("--context-len", type=int, default=512)
|
||||||
|
parser.add_argument("--num-blocks", type=int, default=2048)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
print(args)
|
print(args)
|
||||||
|
|
||||||
for num_blocks in [2048]:
|
for block_size in [16, 32, 64, 128]:
|
||||||
for pages_per_compute_block in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]:
|
for pages_per_compute_block in [1, 2, 4, 8, 16, 32, 64, 128]:
|
||||||
if pages_per_compute_block > MAX_NUM_BLOCKS_PER_SEQ:
|
if pages_per_compute_block > MAX_NUM_BLOCKS_PER_SEQ:
|
||||||
continue
|
continue
|
||||||
print(f"num_blocks: {num_blocks}, pages_per_compute_block: {pages_per_compute_block}")
|
if block_size * pages_per_compute_block > 1024:
|
||||||
|
continue
|
||||||
|
print(f"block_size {block_size}, pages_per_compute_block: {pages_per_compute_block}")
|
||||||
benchmark_paged_attn(
|
benchmark_paged_attn(
|
||||||
args.batch_size,
|
args.batch_size,
|
||||||
args.num_heads,
|
args.num_heads,
|
||||||
args.num_kv_heads,
|
args.num_kv_heads,
|
||||||
args.head_size,
|
args.head_size,
|
||||||
args.context_len,
|
args.context_len,
|
||||||
num_blocks,
|
args.num_blocks,
|
||||||
|
block_size,
|
||||||
pages_per_compute_block,
|
pages_per_compute_block,
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user