From eefbf4a68b7b0a5b8364a59647906be1b7f043e2 Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Fri, 1 Aug 2025 19:18:51 -0400
Subject: [PATCH] [Perf] Optimize `reshape_and_cache_flash` CUDA Kernel
 (#22036)

Signed-off-by: yewentao256
---
 .../benchmark_reshape_and_cache_flash.py | 156 ++++++++++++++++++
 csrc/cache_kernels.cu                    |  92 ++++++++---
 2 files changed, 225 insertions(+), 23 deletions(-)
 create mode 100644 benchmarks/kernels/benchmark_reshape_and_cache_flash.py

diff --git a/benchmarks/kernels/benchmark_reshape_and_cache_flash.py b/benchmarks/kernels/benchmark_reshape_and_cache_flash.py
new file mode 100644
index 0000000000000..d4648c18f31d5
--- /dev/null
+++ b/benchmarks/kernels/benchmark_reshape_and_cache_flash.py
@@ -0,0 +1,156 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

import random
import time

import torch
from tabulate import tabulate

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import (
    STR_DTYPE_TO_TORCH_DTYPE,
    FlexibleArgumentParser,
    create_kv_caches_with_random_flash,
)

logger = init_logger(__name__)


@torch.inference_mode()
def run_benchmark(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    kv_cache_layout: str,
    num_iters: int,
    device: str = "cuda",
) -> float:
    """Return the average latency (seconds) per call for the given num_tokens."""

    if kv_cache_dtype == "fp8" and head_size % 16:
        raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")

    current_platform.seed_everything(42)
    torch.set_default_device(device)

    # create random key / value tensors [T, H, D].
    key = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device)
    value = torch.randn_like(key)

    # prepare the slot mapping.
    # each token is assigned a unique slot in the KV-cache.
    num_slots = block_size * num_blocks
    if num_tokens > num_slots:
        raise ValueError("num_tokens cannot exceed the total number of cache slots")
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    key_caches, value_caches = create_kv_caches_with_random_flash(
        num_blocks,
        block_size,
        1,  # num_layers
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
        cache_layout=kv_cache_layout,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]

    # compute per-tensor scaling factors for fp8 conversion (if used).
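    # note: amax() / 64.0 is a benchmark-local heuristic, not a calibrated
    # scale; it simply gives the fp8 conversion path a plausible per-tensor
    # scale to exercise. In real serving, k_scale/v_scale come from model
    # calibration, and the exact constant here does not affect the measured
    # copy latency.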
    k_scale = (key.amax() / 64.0).to(torch.float32)
    v_scale = (value.amax() / 64.0).to(torch.float32)

    def run_cuda_benchmark(n_iters: int) -> float:
        nonlocal key, value, key_cache, value_cache, slot_mapping
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_iters):
            ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                slot_mapping,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )
        torch.cuda.synchronize()
        end = time.perf_counter()
        return (end - start) / n_iters

    # warm-up
    run_cuda_benchmark(3)

    lat = run_cuda_benchmark(num_iters)

    # free tensors to mitigate OOM when sweeping
    del key, value, key_cache, value_cache, slot_mapping
    torch.cuda.empty_cache()

    return lat


def main(args):
    rows = []
    for layout in ["NHD", "HND"]:
        for exp in range(1, 17):
            n_tok = 2**exp
            lat = run_benchmark(
                num_tokens=n_tok,
                num_heads=args.num_heads,
                head_size=args.head_size,
                block_size=args.block_size,
                num_blocks=args.num_blocks,
                dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
                kv_cache_dtype=args.kv_cache_dtype,
                kv_cache_layout=layout,
                num_iters=args.iters,
                device="cuda",
            )
            rows.append([n_tok, layout, f"{lat * 1e6:.3f}"])

    print(tabulate(rows, headers=["num_tokens", "layout", "latency (µs)"]))


if __name__ == "__main__":
    parser = FlexibleArgumentParser()

    parser.add_argument("--num-heads", type=int, default=128)
    parser.add_argument(
        "--head-size",
        type=int,
        choices=[64, 80, 96, 112, 120, 128, 192, 256],
        default=128,
    )
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--num-blocks", type=int, default=128 * 512)

    parser.add_argument(
        "--dtype",
        type=str,
        choices=["half", "bfloat16", "float"],
        default="bfloat16",
    )

    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8"],
        default="auto",
    )

    parser.add_argument("--iters", type=int, default=100)
    args = parser.parse_args()

    main(args)
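Example invocation of the new benchmark (every flag is optional and falls back to the parser defaults above; the values shown are just one sensible configuration):

    python benchmarks/kernels/benchmark_reshape_and_cache_flash.py \
        --num-heads 64 --head-size 128 --block-size 16 --kv-cache-dtype fp8

The script sweeps num_tokens over powers of two from 2 to 65536, benchmarks both the NHD and HND cache layouts, and prints one latency row per (num_tokens, layout) pair.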
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 88559c8fe7183..131dcb15cd7e9 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -5,6 +5,7 @@
 #include "cuda_utils.h"
 #include "cuda_compat.h"
 #include "dispatch_utils.h"
+#include "quantization/vectorization_utils.cuh"
 
 #ifdef USE_ROCM
   #include "quantization/fp8/amd/quant_utils.cuh"
@@ -261,14 +262,26 @@ __global__ void reshape_and_cache_kernel(
   }
 }
 
+// Used by vectorization_utils to copy/convert one element
+template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
+struct CopyWithScaleOp {
+  float scale;
+
+  __device__ __forceinline__ void operator()(OutT& dst, const InT src) const {
+    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+      dst = static_cast<OutT>(src);
+    } else {
+      dst = fp8::scaled_convert<OutT, InT, kv_dt>(src, scale);
+    }
+  }
+};
+
 template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 __global__ void reshape_and_cache_flash_kernel(
     const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
     const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
-    cache_t* __restrict__ key_cache,   // [num_blocks, block_size, num_heads,
-                                       // head_size]
-    cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads,
-                                       // head_size]
+    cache_t* __restrict__ key_cache,    // NHD or HND; see layout comments below
+    cache_t* __restrict__ value_cache,  // same as above
     const int64_t* __restrict__ slot_mapping,  // [num_tokens]
     const int64_t block_stride, const int64_t page_stride,
     const int64_t head_stride, const int64_t key_stride,
@@ -282,25 +295,58 @@ __global__ void reshape_and_cache_flash_kernel(
   }
   const int64_t block_idx = slot_idx / block_size;
   const int64_t block_offset = slot_idx % block_size;
-  const int n = num_heads * head_size;
-  for (int i = threadIdx.x; i < n; i += blockDim.x) {
-    const int64_t src_key_idx = token_idx * key_stride + i;
-    const int64_t src_value_idx = token_idx * value_stride + i;
-    const int head_idx = i / head_size;
-    const int head_offset = i % head_size;
-    const int64_t tgt_key_value_idx = block_idx * block_stride +
-                                      block_offset * page_stride +
-                                      head_idx * head_stride + head_offset;
-    scalar_t tgt_key = key[src_key_idx];
-    scalar_t tgt_value = value[src_value_idx];
-    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
-      key_cache[tgt_key_value_idx] = tgt_key;
-      value_cache[tgt_key_value_idx] = tgt_value;
-    } else {
-      key_cache[tgt_key_value_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
-      value_cache[tgt_key_value_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
+  const int n_elems = num_heads * head_size;
+
+  // pointers to the beginning of the source row for this token.
+  const scalar_t* __restrict__ key_src = key + token_idx * key_stride;
+  const scalar_t* __restrict__ value_src = value + token_idx * value_stride;
+
+  // find the start position inside the kv-cache for this token.
+  cache_t* __restrict__ key_dst =
+      key_cache + block_idx * block_stride + block_offset * page_stride;
+  cache_t* __restrict__ value_dst =
+      value_cache + block_idx * block_stride + block_offset * page_stride;
+
+  // this is true for the NHD layout where `head_stride == head_size`
+  const bool is_contiguous_heads = (head_stride == head_size);
+
+  float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale;
+  float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale;
+  constexpr int VEC_SIZE = (sizeof(scalar_t) == 2) ? 8 : 4;
+  CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
+  CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
+  if (is_contiguous_heads) {
+    // NHD layout
+    // kv cache: [num_blocks, block_size, num_heads, head_size]
+    vectorize_with_alignment<VEC_SIZE>(key_src, key_dst, n_elems, threadIdx.x,
+                                       blockDim.x, k_op);
+
+    vectorize_with_alignment<VEC_SIZE>(value_src, value_dst, n_elems,
+                                       threadIdx.x, blockDim.x, v_op);
+
+  } else {
+    // HND layout: heads are strided, but each head_size segment is contiguous
+    // kv cache: [num_blocks, num_heads, block_size, head_size]
+    const int lane = threadIdx.x & 31;            // 0..31 within the warp
+    const int warp_id = threadIdx.x >> 5;         // warp index within the block
+    const int warps_per_block = blockDim.x >> 5;
+
+    for (int head = warp_id; head < num_heads; head += warps_per_block) {
+      const scalar_t* __restrict__ k_src_h = key_src + head * head_size;
+      const scalar_t* __restrict__ v_src_h = value_src + head * head_size;
+
+      cache_t* __restrict__ k_dst_h =
+          key_dst + static_cast<int64_t>(head) * head_stride;
+      cache_t* __restrict__ v_dst_h =
+          value_dst + static_cast<int64_t>(head) * head_stride;
+
+      // within each head, let the 32 threads of the warp perform the
+      // vectorized copy
+      vectorize_with_alignment<VEC_SIZE>(k_src_h, k_dst_h, head_size, lane, 32,
+                                         k_op);
+
+      vectorize_with_alignment<VEC_SIZE>(v_src_h, v_dst_h, head_size, lane, 32,
+                                         v_op);
     }
   }
 }
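For reviewers who want the semantics without stepping through the CUDA, the following eager-PyTorch sketch mirrors what the kernel writes on the kv_cache_dtype="auto" path. It is a minimal reference, not part of this patch or of vLLM's API: the helper name and its exact signature are illustrative, and the fp8 path additionally converts each element with the per-tensor scale before the store.

    import torch

    def reshape_and_cache_flash_reference(
        key: torch.Tensor,           # [num_tokens, num_heads, head_size]
        value: torch.Tensor,         # [num_tokens, num_heads, head_size]
        key_cache: torch.Tensor,     # NHD: [num_blocks, block_size, num_heads, head_size]
                                     # HND: [num_blocks, num_heads, block_size, head_size]
        value_cache: torch.Tensor,   # same layout as key_cache
        slot_mapping: torch.Tensor,  # [num_tokens]; -1 marks a padded token
        layout: str = "NHD",
    ) -> None:
        block_size = key_cache.shape[1] if layout == "NHD" else key_cache.shape[2]
        for token, slot in enumerate(slot_mapping.tolist()):
            if slot < 0:  # padded token: nothing to write
                continue
            block_idx, block_offset = divmod(slot, block_size)
            if layout == "NHD":  # all heads for a token are contiguous
                key_cache[block_idx, block_offset] = key[token]
                value_cache[block_idx, block_offset] = value[token]
            else:  # HND: heads are strided across the page
                key_cache[block_idx, :, block_offset] = key[token]
                value_cache[block_idx, :, block_offset] = value[token]

This also makes the layout distinction concrete: in NHD, all num_heads * head_size elements for a token are contiguous, so one strided vectorized copy per token suffices; in HND, each head's head_size elements are contiguous but successive heads sit head_stride apart, which is why the kernel assigns one warp per head there.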