[perf] Use direct copy (broadcast) instead of cat for k_nope/k_pe in MLA prefill (#29710)
Signed-off-by: Ming Yang <minos.future@gmail.com>
This commit is contained in:
parent d02d1043de
commit fba8906930
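The change replaces a torch.cat over an expanded k_pe with a pre-allocated output buffer filled by two sliced copies. The following minimal sketch is not part of the commit; batch size, dtype, and device are illustrative, while the head dimensions match the DeepSeek-V3 shapes used in the benchmark below. It shows the two paths and checks that they produce identical results:

import torch

# Illustrative prefill-like shapes: [tokens, num_heads, head_dim].
k_nope = torch.randn(1024, 128, 128, device="cuda", dtype=torch.bfloat16)
k_pe = torch.randn(1024, 1, 64, device="cuda", dtype=torch.bfloat16)

# Old path: expand k_pe across the head dimension, then concatenate.
k_cat = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

# New path: pre-allocate and write each piece once; the k_pe assignment
# broadcasts over the head dimension.
k_direct = torch.empty(1024, 128, 128 + 64, device="cuda", dtype=torch.bfloat16)
k_direct[..., :128] = k_nope
k_direct[..., 128:] = k_pe

assert torch.equal(k_cat, k_direct)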
benchmarks/kernels/benchmark_mla_k_concat.py (new file, 150 lines)
@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark script comparing torch.cat vs direct copy for k_nope/k_pe concatenation
in MLA (Multi-head Latent Attention) prefill.

This validates that the optimization from commit 8d4142bd is beneficial across
various batch sizes, not just the originally tested batch size of 32768.
"""

import time
from collections.abc import Callable

import torch

# DeepSeek-V3 MLA dimensions
NUM_HEADS = 128
QK_NOPE_HEAD_DIM = 128
PE_DIM = 64


def cat_method(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor:
    """Original torch.cat approach with expand."""
    return torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)


def direct_copy_method(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor:
    """Optimized direct copy approach (avoids expand + cat overhead)."""
    k = torch.empty(
        (*k_nope.shape[:-1], k_nope.shape[-1] + k_pe.shape[-1]),
        dtype=k_nope.dtype,
        device=k_nope.device,
    )
    k[..., : k_nope.shape[-1]] = k_nope
    k[..., k_nope.shape[-1] :] = k_pe
    return k


def benchmark_method(
    method: Callable,
    k_nope: torch.Tensor,
    k_pe: torch.Tensor,
    num_warmup: int = 10,
    num_iters: int = 100,
) -> float:
    """Benchmark a concatenation method and return mean latency in ms."""
    # Warmup
    for _ in range(num_warmup):
        _ = method(k_nope, k_pe)
    torch.cuda.synchronize()

    # Benchmark
    start = time.perf_counter()
    for _ in range(num_iters):
        _ = method(k_nope, k_pe)
    torch.cuda.synchronize()
    end = time.perf_counter()

    return (end - start) / num_iters * 1000  # Convert to ms


@torch.inference_mode()
def run_benchmark(dtype: torch.dtype, dtype_name: str):
    """Run benchmark for a specific dtype."""
    torch.set_default_device("cuda")

    # Batch sizes to test (powers of 2 from 32 to 65536)
    batch_sizes = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536]

    print("=" * 80)
    print("Benchmark: torch.cat vs direct copy for MLA k_nope/k_pe concatenation")
    print("=" * 80)
    print(
        f"Tensor shapes: k_nope=[B, {NUM_HEADS}, {QK_NOPE_HEAD_DIM}], "
        f"k_pe=[B, 1, {PE_DIM}]"
    )
    print(f"dtype: {dtype_name}")
    print()
    print(
        f"{'Batch Size':>12} | {'cat (ms)':>10} | {'direct (ms)':>12} | "
        f"{'Speedup':>8} | {'Reduction':>10}"
    )
    print("-" * 70)

    results = []
    for batch_size in batch_sizes:
        # Create input tensors (generate in float32 then convert for FP8 compatibility)
        k_nope = torch.randn(
            batch_size, NUM_HEADS, QK_NOPE_HEAD_DIM, dtype=torch.float32, device="cuda"
        ).to(dtype)
        k_pe = torch.randn(
            batch_size, 1, PE_DIM, dtype=torch.float32, device="cuda"
        ).to(dtype)

        # Benchmark both methods
        cat_time = benchmark_method(cat_method, k_nope, k_pe)
        direct_time = benchmark_method(direct_copy_method, k_nope, k_pe)

        speedup = cat_time / direct_time
        reduction = (1 - direct_time / cat_time) * 100

        results.append((batch_size, cat_time, direct_time, speedup, reduction))

        print(
            f"{batch_size:>12} | {cat_time:>10.3f} | {direct_time:>12.3f} | "
            f"{speedup:>7.2f}x | {reduction:>9.1f}%"
        )

    print("=" * 80)

    # Summary statistics
    speedups = [r[3] for r in results]
    print("\nSpeedup summary:")
    print(f" Min: {min(speedups):.2f}x")
    print(f" Max: {max(speedups):.2f}x")
    print(f" Mean: {sum(speedups) / len(speedups):.2f}x")

    # Find crossover point
    crossover_batch = None
    for batch_size, _, _, speedup, _ in results:
        if speedup >= 1.0:
            crossover_batch = batch_size
            break

    print("\nConclusion:")
    if crossover_batch:
        print(f" - Direct copy becomes beneficial at batch size >= {crossover_batch}")
    # Filter for large batches (>= 512 which is typical for prefill)
    large_batch_speedups = [r[3] for r in results if r[0] >= 512]
    if large_batch_speedups:
        avg_large = sum(large_batch_speedups) / len(large_batch_speedups)
        print(f" - For batch sizes >= 512: avg speedup = {avg_large:.2f}x")
    print(" - MLA prefill typically uses large batches, so optimization is effective")

    return results


@torch.inference_mode()
def main():
    # Test bfloat16
    print("\n")
    run_benchmark(torch.bfloat16, "bfloat16")

    # Test float8_e4m3fn
    print("\n")
    run_benchmark(torch.float8_e4m3fn, "float8_e4m3fn")


if __name__ == "__main__":
    main()
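For reference, the script above times each method with a host-side time.perf_counter around torch.cuda.synchronize(). A possible alternative, shown here only as a sketch and not part of this commit (time_cuda_ms is a hypothetical helper), is to time with CUDA events, which keeps host timer noise out of the per-iteration numbers at small batch sizes:

import torch

def time_cuda_ms(fn, *args, num_warmup: int = 10, num_iters: int = 100) -> float:
    """Return the mean latency of fn(*args) in milliseconds using CUDA events."""
    # Warmup to exclude compilation/allocation effects
    for _ in range(num_warmup):
        fn(*args)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(num_iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / num_iters  # elapsed_time() already reports ms

Swapping this in for benchmark_method should give comparable rankings between the two methods; absolute latencies may differ slightly from the perf_counter-based numbers.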
@@ -1654,6 +1654,33 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         # Convert from (L, N, P) to (N, P, L)
         self.W_UK_T = W_UK.permute(1, 2, 0)
 
+    def _concat_k_nope_k_pe(
+        self, k_nope: torch.Tensor, k_pe: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Efficiently concatenate k_nope and k_pe tensors along the last dimension.
+
+        This function avoids the performance penalty of torch.cat with expanded
+        non-contiguous tensors by pre-allocating the output and using direct copies.
+
+        Args:
+            k_nope: Tensor of shape [..., nope_dim]
+            k_pe: Tensor to broadcast and concatenate, typically shape [..., 1, pe_dim]
+                or [..., pe_dim]
+
+        Returns:
+            Tensor of shape [..., nope_dim + pe_dim]
+        """
+        k = torch.empty(
+            (*k_nope.shape[:-1], k_nope.shape[-1] + k_pe.shape[-1]),
+            dtype=k_nope.dtype,
+            device=k_nope.device,
+        )
+        # Direct copies with efficient broadcasting
+        k[..., : k_nope.shape[-1]] = k_nope
+        k[..., k_nope.shape[-1] :] = k_pe
+        return k
+
     def _compute_prefill_context(
         self,
         q: torch.Tensor,
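Both k_pe layouts mentioned in the new docstring flow through the same sliced assignment: a per-token k_pe of shape [..., 1, pe_dim] is broadcast across the head dimension, while a k_pe that already matches k_nope's leading dimensions is copied element-wise. A small illustrative check (shapes are made up for this sketch, not taken from the kernel code):

import torch

tokens, heads, nope_dim, pe_dim = 4, 128, 128, 64
k = torch.empty(tokens, heads, nope_dim + pe_dim)
k[..., :nope_dim] = torch.randn(tokens, heads, nope_dim)

# Shared per-token k_pe: [tokens, 1, pe_dim] broadcasts over the head dim.
k_pe_shared = torch.randn(tokens, 1, pe_dim)
k[..., nope_dim:] = k_pe_shared
assert torch.equal(k[0, 0, nope_dim:], k[0, heads - 1, nope_dim:])

# Per-head k_pe: [tokens, heads, pe_dim] is copied directly, no broadcast needed.
k_pe_per_head = torch.randn(tokens, heads, pe_dim)
k[..., nope_dim:] = k_pe_per_head
assert torch.equal(k[..., nope_dim:], k_pe_per_head)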
@@ -1690,7 +1717,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         )
         k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
 
-        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
+        k = self._concat_k_nope_k_pe(k_nope, k_pe)
 
         attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
             prefill=prefill_metadata,
@@ -1794,7 +1821,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
             -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
         )
         k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
+        k = self._concat_k_nope_k_pe(k_nope, k_pe)
 
         attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
             prefill=prefill_metadata,
@@ -1843,7 +1870,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         )
         k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
 
-        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
+        k = self._concat_k_nope_k_pe(k_nope, k_pe)
 
         output_prefill = self._run_prefill_new_tokens(
             prefill=attn_metadata.prefill,