From 5b80f22087589ce55009f4423ce9af7b5f2d43df Mon Sep 17 00:00:00 2001
From: Jiangyun Zhu
Date: Fri, 3 Oct 2025 16:33:46 +0800
Subject: [PATCH] [Perf] Optimize `reshape_and_cache` CUDA Kernel (#25955)

Signed-off-by: zjy0516
Co-authored-by: Liu-congo <1502632128@qq.com>
Signed-off-by: yewentao256
---
 .../kernels/benchmark_reshape_and_cache.py | 174 ++++++++++++++++++
 csrc/cache_kernels.cu                      |  96 +++++-----
 2 files changed, 225 insertions(+), 45 deletions(-)
 create mode 100644 benchmarks/kernels/benchmark_reshape_and_cache.py

diff --git a/benchmarks/kernels/benchmark_reshape_and_cache.py b/benchmarks/kernels/benchmark_reshape_and_cache.py
new file mode 100644
index 0000000000000..af9841daadf24
--- /dev/null
+++ b/benchmarks/kernels/benchmark_reshape_and_cache.py
@@ -0,0 +1,174 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from __future__ import annotations
+
+import random
+import time
+
+import torch
+from tabulate import tabulate
+
+from vllm import _custom_ops as ops
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+from vllm.utils import (
+    STR_DTYPE_TO_TORCH_DTYPE,
+    FlexibleArgumentParser,
+    create_kv_caches_with_random,
+)
+
+logger = init_logger(__name__)
+
+
+@torch.inference_mode()
+def run_benchmark(
+    num_tokens: int,
+    num_heads: int,
+    head_size: int,
+    block_size: int,
+    num_blocks: int,
+    dtype: torch.dtype,
+    kv_cache_dtype: str,
+    num_iters: int,
+    benchmark_mode: str,
+    device: str = "cuda",
+) -> float:
+    """Return latency (seconds) for given num_tokens."""
+
+    if kv_cache_dtype == "fp8" and head_size % 16:
+        raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")
+
+    current_platform.seed_everything(42)
+    torch.set_default_device(device)
+
+    # create random key / value tensors [T, H, D].
+    key = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device)
+    value = torch.randn_like(key)
+
+    # prepare the slot mapping.
+    # each token is assigned a unique slot in the KV-cache.
+    num_slots = block_size * num_blocks
+    if num_tokens > num_slots:
+        raise ValueError("num_tokens cannot exceed the total number of cache slots")
+    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
+    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
+
+    key_caches, value_caches = create_kv_caches_with_random(
+        num_blocks,
+        block_size,
+        1,  # num_layers
+        num_heads,
+        head_size,
+        kv_cache_dtype,
+        dtype,
+        device=device,
+    )
+    key_cache, value_cache = key_caches[0], value_caches[0]
+    # to free unused memory
+    del key_caches, value_caches
+
+    # compute per-kernel scaling factors for fp8 conversion (if used).
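+    # NOTE: the amax() / 64 values below are just simple stand-in scales so the
+    # fp8 path has a non-trivial factor to convert with; they are not produced
+    # by any calibration procedure.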
+    k_scale = (key.amax() / 64.0).to(torch.float32)
+    v_scale = (value.amax() / 64.0).to(torch.float32)
+
+    function_under_test = lambda: ops.reshape_and_cache(
+        key,  # noqa: F821
+        value,  # noqa: F821
+        key_cache,  # noqa: F821
+        value_cache,  # noqa: F821
+        slot_mapping,  # noqa: F821
+        kv_cache_dtype,
+        k_scale,
+        v_scale,
+    )
+
+    if benchmark_mode == "cudagraph":
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g):
+            function_under_test()
+        torch.cuda.synchronize()
+        function_under_test = lambda: g.replay()
+
+    def run_cuda_benchmark(n_iters: int) -> float:
+        nonlocal key, value, key_cache, value_cache, slot_mapping
+        torch.cuda.synchronize()
+        start = time.perf_counter()
+        for _ in range(n_iters):
+            function_under_test()
+        torch.cuda.synchronize()
+        end = time.perf_counter()
+        return (end - start) / n_iters
+
+    # warm-up
+    run_cuda_benchmark(3)
+
+    lat = run_cuda_benchmark(num_iters)
+
+    # free tensors to mitigate OOM when sweeping
+    del key, value, key_cache, value_cache, slot_mapping
+    torch.cuda.empty_cache()
+
+    return lat
+
+
+def main(args):
+    rows = []
+    for exp in range(1, 17):
+        n_tok = 2**exp
+        lat = run_benchmark(
+            num_tokens=n_tok,
+            num_heads=args.num_heads,
+            head_size=args.head_size,
+            block_size=args.block_size,
+            num_blocks=args.num_blocks,
+            dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
+            kv_cache_dtype=args.kv_cache_dtype,
+            num_iters=args.iters,
+            benchmark_mode=args.mode,
+            device="cuda",
+        )
+        rows.append([n_tok, lat * 1e6])  # convert to microseconds
+
+    print(f"Benchmark results for implementation cuda (measuring with {args.mode}):")
+    print(tabulate(rows, headers=["num_tokens", "latency (µs)"], floatfmt=".3f"))
+
+
+if __name__ == "__main__":
+    parser = FlexibleArgumentParser()
+
+    parser.add_argument("--num-heads", type=int, default=128)
+    parser.add_argument(
+        "--head-size",
+        type=int,
+        choices=[64, 80, 96, 112, 120, 128, 192, 256],
+        default=128,
+    )
+    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
+    parser.add_argument("--num-blocks", type=int, default=128 * 128)
+
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        choices=["half", "bfloat16", "float"],
+        default="bfloat16",
+    )
+
+    parser.add_argument(
+        "--kv-cache-dtype",
+        type=str,
+        choices=["auto", "fp8"],
+        default="auto",
+    )
+
+    parser.add_argument("--iters", type=int, default=200)
+
+    parser.add_argument(
+        "--mode",
+        type=str,
+        choices=["cudagraph", "no_graph"],
+        default="cudagraph",
+    )
+
+    args = parser.parse_args()
+
+    main(args)
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index e9d39fc81284b..9a9fb8724c1e4 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -16,8 +16,7 @@
 #include
 #include

-#include
-#include
+#include  // FLT_MIN

 #ifdef USE_ROCM
 #include
@@ -209,6 +208,20 @@ void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,

 namespace vllm {

+// Used to copy/convert one element
+template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
+struct CopyWithScaleOp {
+  float scale;
+
+  __device__ __forceinline__ void operator()(OutT& dst, const InT src) const {
+    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+      dst = static_cast<OutT>(src);
+    } else {
+      dst = fp8::scaled_convert<OutT, InT, kv_dt>(src, scale);
+    }
+  }
+};
+
 template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 __global__ void reshape_and_cache_kernel(
     const scalar_t* __restrict__ key,  // [num_tokens, num_heads, head_size]
@@ -224,59 +237,51 @@ __global__ void reshape_and_cache_kernel(
   const int64_t token_idx = blockIdx.x;
   const int64_t slot_idx = slot_mapping[token_idx];
   if (slot_idx < 0) {
-    // Padding token that should be ignored.
     return;
   }

   const int64_t block_idx = slot_idx / block_size;
   const int64_t block_offset = slot_idx % block_size;
+  const int h_block_count = head_size / x;  // number of x-element groups per head

-  const int n = num_heads * head_size;
-  for (int i = threadIdx.x; i < n; i += blockDim.x) {
-    const int64_t src_key_idx = token_idx * key_stride + i;
-    const int64_t src_value_idx = token_idx * value_stride + i;
+  const int h_block_idx = threadIdx.x;
+  if (h_block_idx >= num_heads * h_block_count) {
+    return;
+  }

-    const int head_idx = i / head_size;
-    const int head_offset = i % head_size;
-    const int x_idx = head_offset / x;
-    const int x_offset = head_offset % x;
+  const int head_idx = h_block_idx / h_block_count;
+  const int h_block = h_block_idx % h_block_count;

-    const int64_t tgt_key_idx =
-        block_idx * num_heads * (head_size / x) * block_size * x +
-        head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
-        block_offset * x + x_offset;
-    const int64_t tgt_value_idx =
-        block_idx * num_heads * head_size * block_size +
-        head_idx * head_size * block_size + head_offset * block_size +
-        block_offset;
-    scalar_t tgt_key = key[src_key_idx];
-    scalar_t tgt_value = value[src_value_idx];
-    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
-      key_cache[tgt_key_idx] = tgt_key;
-      value_cache[tgt_value_idx] = tgt_value;
-    } else {
-      key_cache[tgt_key_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
-      value_cache[tgt_value_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
-    }
+  const scalar_t* __restrict__ key_src =
+      key + token_idx * key_stride + head_idx * head_size + h_block * x;
+  const int64_t src_value_start =
+      token_idx * value_stride + head_idx * head_size + h_block * x;
+
+  cache_t* __restrict__ key_dst =
+      key_cache + block_idx * num_heads * h_block_count * block_size * x +
+      head_idx * h_block_count * block_size * x + h_block * block_size * x +
+      block_offset * x;
+  const int64_t tgt_value_start =
+      block_idx * num_heads * h_block_count * x * block_size +
+      head_idx * h_block_count * x * block_size + h_block * x * block_size +
+      block_offset;
+
+  constexpr int VEC_SIZE = (sizeof(scalar_t) == 2) ? 8 : 4;
+  float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale;
+  CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
+  float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale;
+  CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
+
+  vectorize_with_alignment<VEC_SIZE>(key_src, key_dst, x, 0, 1, k_op);
+
+  const scalar_t* __restrict__ value_src = value + src_value_start;
+  cache_t* __restrict__ value_dst = value_cache + tgt_value_start;
+#pragma unroll
+  for (int i = 0; i < x; i++) {
+    v_op(value_dst[i * block_size], value_src[i]);
   }
 }

-// Used by vectorization_utils to copy/convert one element
-template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
-struct CopyWithScaleOp {
-  float scale;
-
-  __device__ __forceinline__ void operator()(OutT& dst, const InT src) const {
-    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
-      dst = static_cast<OutT>(src);
-    } else {
-      dst = fp8::scaled_convert<OutT, InT, kv_dt>(src, scale);
-    }
-  }
-};
-
 template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 __global__ void reshape_and_cache_flash_kernel(
     const scalar_t* __restrict__ key,  // [num_tokens, num_heads, head_size]
@@ -601,9 +606,10 @@ void reshape_and_cache(
   int key_stride = key.stride(0);
   int value_stride = value.stride(0);
+  int head_div_x = head_size / x;

   dim3 grid(num_tokens);
-  dim3 block(std::min(num_heads * head_size, 512));
+  dim3 block(std::min(num_heads * head_div_x, 512));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
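For reviewers who want to sanity-check the new indexing: the kernel still writes the
standard vLLM layouts, key_cache[num_blocks, num_heads, head_size // x, block_size, x]
and value_cache[num_blocks, num_heads, head_size, block_size]; what changes is only the
thread mapping (one thread per (head, x-group) pair copying x contiguous key elements,
instead of one element per thread). Below is a minimal PyTorch sketch of the expected
result for the kv_cache_dtype="auto" path, usable as a reference against
ops.reshape_and_cache. The helper name reshape_and_cache_reference is illustrative and
not part of this patch; the fp8 path would additionally apply k_scale / v_scale.

    import torch

    def reshape_and_cache_reference(
        key, value, key_cache, value_cache, slot_mapping, block_size, x
    ):
        # key/value:   [num_tokens, num_heads, head_size]
        # key_cache:   [num_blocks, num_heads, head_size // x, block_size, x]
        # value_cache: [num_blocks, num_heads, head_size, block_size]
        num_tokens, num_heads, head_size = key.shape
        for token_idx in range(num_tokens):
            slot_idx = int(slot_mapping[token_idx])
            if slot_idx < 0:  # padding token: nothing is written
                continue
            block_idx, block_offset = divmod(slot_idx, block_size)
            # per token, the kernel produces the same reshaped copy regardless of
            # how the work is split across threads:
            key_cache[block_idx, :, :, block_offset, :] = key[token_idx].reshape(
                num_heads, head_size // x, x
            )
            value_cache[block_idx, :, :, block_offset] = value[token_idx]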