From 47b93395463d9d2ddc2c1176d6815fdc8e505afc Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 2 Oct 2025 23:35:47 -0400 Subject: [PATCH] [DeepSeek] Improve performance of DS MLA cache kernel (#26132) Signed-off-by: Matthew Bonanni --- csrc/cache_kernels.cu | 124 ++++++++++++++++++++---------------------- 1 file changed, 59 insertions(+), 65 deletions(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 1286f5806d4b6..c7eeef8bfa3ad 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -16,7 +16,6 @@ #include #include -#include // FLT_MIN #include #include @@ -424,84 +423,80 @@ __global__ void concat_and_cache_ds_mla_kernel( const int64_t dst_idx_start = block_idx * block_stride + block_offset * entry_stride; - // Create 4 tile scales in shared memory - __shared__ float smem[20]; - float* shard_abs_max = smem; - float* tile_scales = smem + 16; - - // For the NoPE part, each tile of 128 elements is handled by 4 warps - // (128 threads). There are 4 total tiles, so 16 warps (512 threads). - // The first thread of the first warp in each tile writes the scale - // value for the tile. The RoPE part (last 64 elements) is handled - // by another 2 warps (64 threads). - // So in total, we use 18 warps (576 threads) per block. + // For the NoPE part, each tile of 128 elements is handled by half of one warp + // (16 threads). There are 4 total tiles, so 2 warps (64 threads). + // Lanes 0 and 16 of each warp write the scale values for that warp's tiles. + // The RoPE part (last 64 elements) is handled by another 1 warp (32 threads). + // So in total, we use 3 warps (96 threads) per block. // Cast kv_cache to 16_bit for RoPE values scalar_t* kv_cache_16bit = reinterpret_cast(&kv_cache[dst_idx_start]); - // The last 64 threads handle the RoPE part - if (threadIdx.x >= kv_lora_rank) { - const int8_t pe_idx = threadIdx.x - kv_lora_rank; - const int64_t src_idx = token_idx * k_pe_stride + pe_idx; + // The last warp handles the RoPE part + if (threadIdx.x >= 64) { + // Each thread handles two elements of RoPE + const int8_t pe_idx_start = (threadIdx.x - 64) * 2; + const int64_t src_idx = token_idx * k_pe_stride + pe_idx_start; + // Vectorized load of two 16-bit values, performed as one 32-bit load + const int32_t vals = *reinterpret_cast(&k_pe[src_idx]); // RoPE values start after the packed 8-bit NoPE values and the // 32-bit scales - const int64_t dst_idx = kv_lora_rank / 2 + 8 + pe_idx; - kv_cache_16bit[dst_idx] = k_pe[src_idx]; + const int64_t dst_idx = kv_lora_rank / 2 + 8 + pe_idx_start; + // Vectorized store of two 16-bit values, performed as one 32-bit store + *reinterpret_cast(&kv_cache_16bit[dst_idx]) = vals; return; } - // Determine the scale for each chunk of NoPE - const int16_t tile_idx = threadIdx.x >> 7; - const int16_t warp_idx = (threadIdx.x & 127) >> 5; - const int16_t lane_idx = threadIdx.x & 31; + // The first two warps handle the NoPE part + const int8_t warp_idx = threadIdx.x >> 5; + const int8_t lane_idx = threadIdx.x & 31; + const int8_t tile_idx = warp_idx * 2 + (lane_idx >> 4); - // Load the NoPE element for this thread into registers - const int64_t src_idx = token_idx * kv_c_stride + threadIdx.x; - const scalar_t src_val = kv_c[src_idx]; + // Each thread handles 8 elements of NoPE + // Load the NoPE elements for this thread into registers + const int64_t src_idx_start = token_idx * kv_c_stride + (threadIdx.x * 8); + // Vectorized load of eight 16-bit values, performed as an int4 load + const int4 vals_i4 = *reinterpret_cast(&kv_c[src_idx_start]); + const scalar_t* vals = reinterpret_cast(&vals_i4); - // Warp-level reduction to find the max absolute value in the warp - float max_abs = fabsf(src_val); + // Max absolute value of this thread's elements + float max_abs = fmaxf(fmaxf(fmaxf(fabsf(vals[0]), fabsf(vals[1])), + fmaxf(fabsf(vals[2]), fabsf(vals[3]))), + fmaxf(fmaxf(fabsf(vals[4]), fabsf(vals[5])), + fmaxf(fabsf(vals[6]), fabsf(vals[7])))); + + // Warp-level reduction to find the max absolute value in each half-warp #pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { -#ifdef USE_ROCM - max_abs = fmaxf(max_abs, __shfl_down_sync(UINT64_MAX, max_abs, offset)); -#else - max_abs = fmaxf(max_abs, __shfl_down_sync(0xFFFFFFFF, max_abs, offset)); -#endif + for (int offset = 8; offset > 0; offset /= 2) { + max_abs = fmaxf(max_abs, VLLM_SHFL_XOR_SYNC_WIDTH(max_abs, offset, 16)); } - // The first lane of each warp in each tile writes the max_abs of this part - // of the tile to shared memory - if (lane_idx == 0) { - shard_abs_max[tile_idx * 4 + warp_idx] = max_abs; - } - __syncthreads(); + // Compute the scale for the tile + float tile_scale = max_abs / 448.f; - // The first lane of the first warp in each tile computes the scale for the - // tile and writes it to shared memory and to kv_cache - if (warp_idx == 0 && lane_idx == 0) { - float4 shard_abs_max_vec = - reinterpret_cast(shard_abs_max)[tile_idx]; - float tile_scale = fmaxf(fmaxf(shard_abs_max_vec.x, shard_abs_max_vec.y), - fmaxf(shard_abs_max_vec.z, shard_abs_max_vec.w)) / - 448.f; - - // Avoid division by zero in `scaled_convert` - tile_scales[tile_idx] = fmaxf(tile_scale, FLT_MIN); + // The first lane of each half-warp writes the scale to kv_cache + if ((lane_idx == 0) || (lane_idx == 16)) { float* kv_cache_32bit = reinterpret_cast(&kv_cache[dst_idx_start]); const uint64_t dst_idx = kv_lora_rank / 4 + tile_idx; - kv_cache_32bit[dst_idx] = tile_scales[tile_idx]; + kv_cache_32bit[dst_idx] = tile_scale; } - __syncthreads(); + // Now all threads in the block scale and write their elements + // NoPE data is packed in the first kv_lora_rank/2 bytes (first 256 bytes) + const int64_t dst_idx_base = dst_idx_start + (threadIdx.x * 8); - // Now all threads in the block scale and write their element - const float scale_val = tile_scales[tile_idx]; - const int64_t dst_idx = dst_idx_start + threadIdx.x; - kv_cache[dst_idx] = - fp8::scaled_convert( - src_val, scale_val); + uint8_t result[8]; +#pragma unroll + for (int i = 0; i < 8; i++) { + result[i] = + fp8::scaled_convert( + vals[i], tile_scale); + } + + // Store as aligned 64-bit writes + *reinterpret_cast(&kv_cache[dst_idx_base]) = + *reinterpret_cast(result); } template @@ -741,13 +736,12 @@ void concat_and_cache_mla( if (kv_cache_dtype == "fp8_ds_mla") { dim3 grid(num_tokens); - // For the NoPE part, each tile of 128 elements is handled by 4 warps - // (128 threads). There are 4 total tiles, so 16 warps (512 threads). - // The first thread of the first warp in each tile writes the scale - // value for the tile. The RoPE part (last 64 elements) is handled - // by another 2 warps (64 threads). - // So in total, we use 18 warps (576 threads) per block. - dim3 block(576); + // For the NoPE part, each tile of 128 elements is handled by half of one + // warp (16 threads). There are 4 total tiles, so 2 warps (64 threads). + // Lanes 0 and 16 of each warp write the scale values for that warp's tiles. + // The RoPE part (last 64 elements) is handled by another 1 warp (32 + // threads). So in total, we use 3 warps (96 threads) per block. + dim3 block(96); DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, CALL_CONCAT_AND_CACHE_DS_MLA); } else {