[DeepSeek] Improve performance of DS MLA cache kernel (#26132)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Author: Matthew Bonanni <mbonanni@redhat.com>
Date: 2025-10-02 23:35:47 -04:00 (committed by GitHub)
Parent: 5d5146eee3
Commit: 47b9339546
GPG Key ID: B5690EEEBB952194


@@ -16,7 +16,6 @@
 #include <algorithm>
 #include <cassert>
-#include <cfloat>  // FLT_MIN
 #include <map>
 #include <vector>
@@ -424,84 +423,80 @@ __global__ void concat_and_cache_ds_mla_kernel(
   const int64_t dst_idx_start =
       block_idx * block_stride + block_offset * entry_stride;
 
-  // Create 4 tile scales in shared memory
-  __shared__ float smem[20];
-  float* shard_abs_max = smem;
-  float* tile_scales = smem + 16;
-
-  // For the NoPE part, each tile of 128 elements is handled by 4 warps
-  // (128 threads). There are 4 total tiles, so 16 warps (512 threads).
-  // The first thread of the first warp in each tile writes the scale
-  // value for the tile. The RoPE part (last 64 elements) is handled
-  // by another 2 warps (64 threads).
-  // So in total, we use 18 warps (576 threads) per block.
+  // For the NoPE part, each tile of 128 elements is handled by half of one warp
+  // (16 threads). There are 4 total tiles, so 2 warps (64 threads).
+  // Lanes 0 and 16 of each warp write the scale values for that warp's tiles.
+  // The RoPE part (last 64 elements) is handled by another 1 warp (32 threads).
+  // So in total, we use 3 warps (96 threads) per block.
 
   // Cast kv_cache to 16_bit for RoPE values
   scalar_t* kv_cache_16bit =
       reinterpret_cast<scalar_t*>(&kv_cache[dst_idx_start]);
 
-  // The last 64 threads handle the RoPE part
-  if (threadIdx.x >= kv_lora_rank) {
-    const int8_t pe_idx = threadIdx.x - kv_lora_rank;
-    const int64_t src_idx = token_idx * k_pe_stride + pe_idx;
+  // The last warp handles the RoPE part
+  if (threadIdx.x >= 64) {
+    // Each thread handles two elements of RoPE
+    const int8_t pe_idx_start = (threadIdx.x - 64) * 2;
+    const int64_t src_idx = token_idx * k_pe_stride + pe_idx_start;
+    // Vectorized load of two 16-bit values, performed as one 32-bit load
+    const int32_t vals = *reinterpret_cast<const int32_t*>(&k_pe[src_idx]);
     // RoPE values start after the packed 8-bit NoPE values and the
     // 32-bit scales
-    const int64_t dst_idx = kv_lora_rank / 2 + 8 + pe_idx;
-    kv_cache_16bit[dst_idx] = k_pe[src_idx];
+    const int64_t dst_idx = kv_lora_rank / 2 + 8 + pe_idx_start;
+    // Vectorized store of two 16-bit values, performed as one 32-bit store
+    *reinterpret_cast<int32_t*>(&kv_cache_16bit[dst_idx]) = vals;
     return;
   }
 
-  // Determine the scale for each chunk of NoPE
-  const int16_t tile_idx = threadIdx.x >> 7;
-  const int16_t warp_idx = (threadIdx.x & 127) >> 5;
-  const int16_t lane_idx = threadIdx.x & 31;
+  // The first two warps handle the NoPE part
+  const int8_t warp_idx = threadIdx.x >> 5;
+  const int8_t lane_idx = threadIdx.x & 31;
+  const int8_t tile_idx = warp_idx * 2 + (lane_idx >> 4);
 
-  // Load the NoPE element for this thread into registers
-  const int64_t src_idx = token_idx * kv_c_stride + threadIdx.x;
-  const scalar_t src_val = kv_c[src_idx];
+  // Each thread handles 8 elements of NoPE
+  // Load the NoPE elements for this thread into registers
+  const int64_t src_idx_start = token_idx * kv_c_stride + (threadIdx.x * 8);
+  // Vectorized load of eight 16-bit values, performed as an int4 load
+  const int4 vals_i4 = *reinterpret_cast<const int4*>(&kv_c[src_idx_start]);
+  const scalar_t* vals = reinterpret_cast<const scalar_t*>(&vals_i4);
 
-  // Warp-level reduction to find the max absolute value in the warp
-  float max_abs = fabsf(src_val);
+  // Max absolute value of this thread's elements
+  float max_abs = fmaxf(fmaxf(fmaxf(fabsf(vals[0]), fabsf(vals[1])),
+                              fmaxf(fabsf(vals[2]), fabsf(vals[3]))),
+                        fmaxf(fmaxf(fabsf(vals[4]), fabsf(vals[5])),
+                              fmaxf(fabsf(vals[6]), fabsf(vals[7]))));
+
+  // Warp-level reduction to find the max absolute value in each half-warp
 #pragma unroll
-  for (int offset = 16; offset > 0; offset /= 2) {
-#ifdef USE_ROCM
-    max_abs = fmaxf(max_abs, __shfl_down_sync(UINT64_MAX, max_abs, offset));
-#else
-    max_abs = fmaxf(max_abs, __shfl_down_sync(0xFFFFFFFF, max_abs, offset));
-#endif
+  for (int offset = 8; offset > 0; offset /= 2) {
+    max_abs = fmaxf(max_abs, VLLM_SHFL_XOR_SYNC_WIDTH(max_abs, offset, 16));
   }
 
-  // The first lane of each warp in each tile writes the max_abs of this part
-  // of the tile to shared memory
-  if (lane_idx == 0) {
-    shard_abs_max[tile_idx * 4 + warp_idx] = max_abs;
-  }
-  __syncthreads();
+  // Compute the scale for the tile
+  float tile_scale = max_abs / 448.f;
 
-  // The first lane of the first warp in each tile computes the scale for the
-  // tile and writes it to shared memory and to kv_cache
-  if (warp_idx == 0 && lane_idx == 0) {
-    float4 shard_abs_max_vec =
-        reinterpret_cast<float4*>(shard_abs_max)[tile_idx];
-    float tile_scale = fmaxf(fmaxf(shard_abs_max_vec.x, shard_abs_max_vec.y),
-                             fmaxf(shard_abs_max_vec.z, shard_abs_max_vec.w)) /
-                       448.f;
-
-    // Avoid division by zero in `scaled_convert`
-    tile_scales[tile_idx] = fmaxf(tile_scale, FLT_MIN);
+  // The first lane of each half-warp writes the scale to kv_cache
+  if ((lane_idx == 0) || (lane_idx == 16)) {
     float* kv_cache_32bit = reinterpret_cast<float*>(&kv_cache[dst_idx_start]);
     const uint64_t dst_idx = kv_lora_rank / 4 + tile_idx;
-    kv_cache_32bit[dst_idx] = tile_scales[tile_idx];
+    kv_cache_32bit[dst_idx] = tile_scale;
   }
 
-  __syncthreads();
-
-  // Now all threads in the block scale and write their element
-  const float scale_val = tile_scales[tile_idx];
-  const int64_t dst_idx = dst_idx_start + threadIdx.x;
-  kv_cache[dst_idx] =
-      fp8::scaled_convert<uint8_t, scalar_t, Fp8KVCacheDataType::kFp8E4M3>(
-          src_val, scale_val);
+  // Now all threads in the block scale and write their elements
+  // NoPE data is packed in the first kv_lora_rank bytes (first 512 bytes)
+  const int64_t dst_idx_base = dst_idx_start + (threadIdx.x * 8);
+  uint8_t result[8];
+#pragma unroll
+  for (int i = 0; i < 8; i++) {
+    result[i] =
+        fp8::scaled_convert<uint8_t, scalar_t, Fp8KVCacheDataType::kFp8E4M3>(
+            vals[i], tile_scale);
+  }
+  // Store as aligned 64-bit writes
+  *reinterpret_cast<uint64_t*>(&kv_cache[dst_idx_base]) =
+      *reinterpret_cast<const uint64_t*>(result);
 }
 
 template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
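
The key change above is that each 128-element tile is now reduced entirely within a 16-lane half-warp, so the old two-stage shared-memory reduction and both __syncthreads() barriers disappear. The following is a minimal self-contained CUDA sketch of that pattern, not part of this commit: the kernel name and launch shape are hypothetical, and it assumes the vLLM macro VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) behaves like __shfl_xor_sync(0xffffffff, var, lane_mask, width).

#include <cuda_fp16.h>

// Illustrative kernel: blockDim.x == 64, so there are four 16-lane
// half-warps, one per 128-element tile, mirroring the NoPE path above.
__global__ void half_warp_absmax(const __half* in, float* scales) {
  // 128-bit vectorized load: eight contiguous 16-bit values per thread
  // (requires 16-byte alignment of `in`).
  const int4 v4 = *reinterpret_cast<const int4*>(&in[threadIdx.x * 8]);
  const __half* v = reinterpret_cast<const __half*>(&v4);

  // Per-thread max absolute value of the eight loaded elements.
  float max_abs = 0.f;
#pragma unroll
  for (int i = 0; i < 8; i++) {
    max_abs = fmaxf(max_abs, fabsf(__half2float(v[i])));
  }

  // Butterfly (XOR) reduction restricted to a 16-lane half-warp. After
  // log2(16) = 4 steps every lane of the half-warp holds the tile maximum,
  // with no shared memory and no __syncthreads().
#pragma unroll
  for (int offset = 8; offset > 0; offset /= 2) {
    max_abs = fmaxf(max_abs, __shfl_xor_sync(0xffffffff, max_abs, offset, 16));
  }

  // Lanes 0 and 16 of each warp publish their tile's scale, mirroring the
  // `(lane_idx == 0) || (lane_idx == 16)` branch in the kernel above.
  if ((threadIdx.x & 15) == 0) {
    scales[threadIdx.x >> 4] = max_abs / 448.f;  // 448 = max finite fp8e4m3
  }
}

Because the shuffle width is 16, no exchange ever crosses a half-warp boundary, which is what lets two tiles share one warp and lets lanes 0 and 16 each hold a complete tile maximum after four steps.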
@@ -741,13 +736,12 @@ void concat_and_cache_mla(
   if (kv_cache_dtype == "fp8_ds_mla") {
     dim3 grid(num_tokens);
-    // For the NoPE part, each tile of 128 elements is handled by 4 warps
-    // (128 threads). There are 4 total tiles, so 16 warps (512 threads).
-    // The first thread of the first warp in each tile writes the scale
-    // value for the tile. The RoPE part (last 64 elements) is handled
-    // by another 2 warps (64 threads).
-    // So in total, we use 18 warps (576 threads) per block.
-    dim3 block(576);
+    // For the NoPE part, each tile of 128 elements is handled by half of one
+    // warp (16 threads). There are 4 total tiles, so 2 warps (64 threads).
+    // Lanes 0 and 16 of each warp write the scale values for that warp's tiles.
+    // The RoPE part (last 64 elements) is handled by another 1 warp (32
+    // threads). So in total, we use 3 warps (96 threads) per block.
+    dim3 block(96);
     DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
                                CALL_CONCAT_AND_CACHE_DS_MLA);
   } else {
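
As a sanity check on the kernel's destination indices, the per-token cache-entry layout can be reconstructed from the code above. The sketch below is illustrative only; it assumes kv_lora_rank = 512 with a 64-element RoPE part (the DeepSeek MLA shapes this kernel targets), and names like nope_end are made up for the example.

#include <cstdint>
#include <cstdio>

int main() {
  // Assumed dimensions: 512 NoPE channels, 64 RoPE channels.
  constexpr int64_t kv_lora_rank = 512;
  constexpr int64_t pe_dim = 64;

  // Packed fp8e4m3 NoPE values, one byte each: bytes [0, 512).
  constexpr int64_t nope_begin = 0;
  constexpr int64_t nope_end = nope_begin + kv_lora_rank;

  // Four fp32 tile scales; the kernel writes float index kv_lora_rank / 4,
  // i.e. byte 512: bytes [512, 528).
  constexpr int64_t scales_begin = (kv_lora_rank / 4) * 4;
  constexpr int64_t scales_end = scales_begin + 4 * 4;

  // RoPE values; the kernel writes 16-bit index kv_lora_rank / 2 + 8,
  // i.e. byte 528: bytes [528, 656).
  constexpr int64_t rope_begin = (kv_lora_rank / 2 + 8) * 2;
  constexpr int64_t rope_end = rope_begin + pe_dim * 2;

  static_assert(scales_begin == nope_end, "scales follow NoPE");
  static_assert(rope_begin == scales_end, "RoPE follows scales");

  printf("NoPE  : bytes [%lld, %lld)\n", (long long)nope_begin, (long long)nope_end);
  printf("scales: bytes [%lld, %lld)\n", (long long)scales_begin, (long long)scales_end);
  printf("RoPE  : bytes [%lld, %lld)\n", (long long)rope_begin, (long long)rope_end);
  return 0;
}

This layout also explains the 96-thread block: 64 threads cover the 512 NoPE elements at 8 apiece, and 32 threads cover the 64 RoPE elements at 2 apiece.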