diff --git a/benchmarks/kernels/benchmark_mla_k_concat.py b/benchmarks/kernels/benchmark_mla_k_concat.py
new file mode 100644
index 0000000000000..fb3b6c8f12003
--- /dev/null
+++ b/benchmarks/kernels/benchmark_mla_k_concat.py
@@ -0,0 +1,150 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Benchmark script comparing torch.cat vs direct copy for k_nope/k_pe concatenation
+in MLA (Multi-head Latent Attention) prefill.
+
+This validates that the optimization from commit 8d4142bd is beneficial across
+various batch sizes, not just the originally tested batch size of 32768.
+"""
+
+import time
+from collections.abc import Callable
+
+import torch
+
+# DeepSeek-V3 MLA dimensions
+NUM_HEADS = 128
+QK_NOPE_HEAD_DIM = 128
+PE_DIM = 64
+
+
+def cat_method(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor:
+    """Original torch.cat approach with expand."""
+    return torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
+
+
+def direct_copy_method(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor:
+    """Optimized direct copy approach (avoids expand + cat overhead)."""
+    k = torch.empty(
+        (*k_nope.shape[:-1], k_nope.shape[-1] + k_pe.shape[-1]),
+        dtype=k_nope.dtype,
+        device=k_nope.device,
+    )
+    k[..., : k_nope.shape[-1]] = k_nope
+    k[..., k_nope.shape[-1] :] = k_pe
+    return k
+
+
+def benchmark_method(
+    method: Callable,
+    k_nope: torch.Tensor,
+    k_pe: torch.Tensor,
+    num_warmup: int = 10,
+    num_iters: int = 100,
+) -> float:
+    """Benchmark a concatenation method and return mean latency in ms."""
+    # Warmup
+    for _ in range(num_warmup):
+        _ = method(k_nope, k_pe)
+    torch.cuda.synchronize()
+
+    # Benchmark
+    start = time.perf_counter()
+    for _ in range(num_iters):
+        _ = method(k_nope, k_pe)
+    torch.cuda.synchronize()
+    end = time.perf_counter()
+
+    return (end - start) / num_iters * 1000  # Convert to ms
+
+
+@torch.inference_mode()
+def run_benchmark(dtype: torch.dtype, dtype_name: str):
+    """Run benchmark for a specific dtype."""
+    torch.set_default_device("cuda")
+
+    # Batch sizes to test (powers of 2 from 32 to 65536)
+    batch_sizes = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536]
+
+    print("=" * 80)
+    print("Benchmark: torch.cat vs direct copy for MLA k_nope/k_pe concatenation")
+    print("=" * 80)
+    print(
+        f"Tensor shapes: k_nope=[B, {NUM_HEADS}, {QK_NOPE_HEAD_DIM}], "
+        f"k_pe=[B, 1, {PE_DIM}]"
+    )
+    print(f"dtype: {dtype_name}")
+    print()
+    print(
+        f"{'Batch Size':>12} | {'cat (ms)':>10} | {'direct (ms)':>12} | "
+        f"{'Speedup':>8} | {'Reduction':>10}"
+    )
+    print("-" * 70)
+
+    results = []
+    for batch_size in batch_sizes:
+        # Create input tensors (generate in float32 then convert for FP8 compatibility)
+        k_nope = torch.randn(
+            batch_size, NUM_HEADS, QK_NOPE_HEAD_DIM, dtype=torch.float32, device="cuda"
+        ).to(dtype)
+        k_pe = torch.randn(
+            batch_size, 1, PE_DIM, dtype=torch.float32, device="cuda"
+        ).to(dtype)
+
+        # Benchmark both methods
+        cat_time = benchmark_method(cat_method, k_nope, k_pe)
+        direct_time = benchmark_method(direct_copy_method, k_nope, k_pe)
+
+        speedup = cat_time / direct_time
+        reduction = (1 - direct_time / cat_time) * 100
+
+        results.append((batch_size, cat_time, direct_time, speedup, reduction))
+
+        print(
+            f"{batch_size:>12} | {cat_time:>10.3f} | {direct_time:>12.3f} | "
+            f"{speedup:>7.2f}x | {reduction:>9.1f}%"
+        )
+
+    print("=" * 80)
+
+    # Summary statistics
+    speedups = [r[3] for r in results]
+    print("\nSpeedup summary:")
+    print(f"  Min:  {min(speedups):.2f}x")
+    print(f"  Max:  {max(speedups):.2f}x")
+    print(f"  Mean: {sum(speedups) / len(speedups):.2f}x")
+
+    # Find crossover point
+    crossover_batch = None
+    for batch_size, _, _, speedup, _ in results:
+        if speedup >= 1.0:
+            crossover_batch = batch_size
+            break
+
+    print("\nConclusion:")
+    if crossover_batch:
+        print(f"  - Direct copy becomes beneficial at batch size >= {crossover_batch}")
+    # Filter for large batches (>= 512 which is typical for prefill)
+    large_batch_speedups = [r[3] for r in results if r[0] >= 512]
+    if large_batch_speedups:
+        avg_large = sum(large_batch_speedups) / len(large_batch_speedups)
+        print(f"  - For batch sizes >= 512: avg speedup = {avg_large:.2f}x")
+    print("  - MLA prefill typically uses large batches, so optimization is effective")
+
+    return results
+
+
+@torch.inference_mode()
+def main():
+    # Test bfloat16
+    print("\n")
+    run_benchmark(torch.bfloat16, "bfloat16")
+
+    # Test float8_e4m3fn
+    print("\n")
+    run_benchmark(torch.float8_e4m3fn, "float8_e4m3fn")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 0a5257a1d87d8..8265503c28c35 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -1654,6 +1654,33 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         # Convert from (L, N, P) to (N, P, L)
         self.W_UK_T = W_UK.permute(1, 2, 0)
 
+    def _concat_k_nope_k_pe(
+        self, k_nope: torch.Tensor, k_pe: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Efficiently concatenate k_nope and k_pe tensors along the last dimension.
+
+        This function avoids the performance penalty of torch.cat with expanded
+        non-contiguous tensors by pre-allocating the output and using direct copies.
+
+        Args:
+            k_nope: Tensor of shape [..., nope_dim]
+            k_pe: Tensor to broadcast and concatenate, typically shape [..., 1, pe_dim]
+                or [..., pe_dim]
+
+        Returns:
+            Tensor of shape [..., nope_dim + pe_dim]
+        """
+        k = torch.empty(
+            (*k_nope.shape[:-1], k_nope.shape[-1] + k_pe.shape[-1]),
+            dtype=k_nope.dtype,
+            device=k_nope.device,
+        )
+        # Direct copies with efficient broadcasting
+        k[..., : k_nope.shape[-1]] = k_nope
+        k[..., k_nope.shape[-1] :] = k_pe
+        return k
+
     def _compute_prefill_context(
         self,
         q: torch.Tensor,
@@ -1690,7 +1717,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
             )
             k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
 
-            k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
+            k = self._concat_k_nope_k_pe(k_nope, k_pe)
 
             attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
                 prefill=prefill_metadata,
@@ -1794,7 +1821,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
                 -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
             )
             k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-            k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
+            k = self._concat_k_nope_k_pe(k_nope, k_pe)
 
             attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
                 prefill=prefill_metadata,
@@ -1843,7 +1870,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
        )
        k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
 
-        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
+        k = self._concat_k_nope_k_pe(k_nope, k_pe)
 
        output_prefill = self._run_prefill_new_tokens(
            prefill=attn_metadata.prefill,