From 0ec82edda59aaf5cf3b07aadf4ecce1aa1131add Mon Sep 17 00:00:00 2001 From: Himanshu Jaju Date: Mon, 21 Jul 2025 19:19:23 +0100 Subject: [PATCH] [perf] Speed up align sum kernels (#21079) Signed-off-by: Himanshu Jaju --- .../kernels/benchmark_moe_align_block_size.py | 7 +- csrc/moe/moe_align_sum_kernels.cu | 71 ++++++++++++++----- .../layers/fused_moe/moe_align_block_size.py | 7 +- 3 files changed, 60 insertions(+), 25 deletions(-) diff --git a/benchmarks/kernels/benchmark_moe_align_block_size.py b/benchmarks/kernels/benchmark_moe_align_block_size.py index 5170ac09dc42a..1af5a21caf465 100644 --- a/benchmarks/kernels/benchmark_moe_align_block_size.py +++ b/benchmarks/kernels/benchmark_moe_align_block_size.py @@ -33,15 +33,13 @@ def check_correctness(num_tokens, num_experts=256, block_size=256, topk=8): sorted_ids_triton = torch.empty( (max_num_tokens_padded,), dtype=torch.int32, device="cuda" ) - sorted_ids_triton.fill_(topk_ids.numel()) # fill with sentinel value - expert_ids_triton = torch.zeros( + expert_ids_triton = torch.empty( (max_num_tokens_padded // block_size,), dtype=torch.int32, device="cuda" ) num_tokens_post_pad_triton = torch.empty((1,), dtype=torch.int32, device="cuda") sorted_ids_vllm = torch.empty_like(sorted_ids_triton) - sorted_ids_vllm.fill_(topk_ids.numel()) - expert_ids_vllm = torch.zeros_like(expert_ids_triton) + expert_ids_vllm = torch.empty_like(expert_ids_triton) num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_triton) # 2. run implementations @@ -102,7 +100,6 @@ def benchmark(num_tokens, num_experts, topk, provider): max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda") - sorted_ids.fill_(topk_ids.numel()) max_num_m_blocks = max_num_tokens_padded // block_size expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda") num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda") diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index 462dbd1f8b380..8bbcf5a673fd3 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -1,6 +1,7 @@ #include #include #include +#include #include #include @@ -19,9 +20,14 @@ __global__ void moe_align_block_size_kernel( int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts, int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size, - size_t numel, int32_t* __restrict__ cumsum) { + size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded) { extern __shared__ int32_t shared_counts[]; + // Initialize sorted_token_ids with numel + for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) { + sorted_token_ids[it] = numel; + } + const int warp_id = threadIdx.x / WARP_SIZE; const int my_expert_start = warp_id * experts_per_warp; @@ -45,18 +51,27 @@ __global__ void moe_align_block_size_kernel( __syncthreads(); - if (threadIdx.x == 0) { - cumsum[0] = 0; - for (int i = 1; i <= num_experts; ++i) { - int expert_count = 0; - int warp_idx = (i - 1) / experts_per_warp; - int expert_offset = (i - 1) % experts_per_warp; - expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset]; + // Compute prefix sum over token counts per expert + using BlockScan = cub::BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; - cumsum[i] = - cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size; - } - *total_tokens_post_pad = cumsum[num_experts]; + int expert_count = 0; + int expert_id = threadIdx.x; + if (expert_id < num_experts) { + int warp_idx = expert_id / experts_per_warp; + int expert_offset = expert_id % experts_per_warp; + expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset]; + expert_count = CEILDIV(expert_count, block_size) * block_size; + } + + int cumsum_val; + BlockScan(temp_storage).ExclusiveSum(expert_count, cumsum_val); + if (expert_id <= num_experts) { + cumsum[expert_id] = cumsum_val; + } + + if (expert_id == num_experts) { + *total_tokens_post_pad = cumsum_val; } __syncthreads(); @@ -67,6 +82,13 @@ __global__ void moe_align_block_size_kernel( expert_ids[i / block_size] = threadIdx.x; } } + + // Fill remaining expert_ids with 0 + const size_t fill_start_idx = cumsum[num_experts] / block_size + threadIdx.x; + const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size); + for (size_t i = fill_start_idx; i < expert_ids_size; i += blockDim.x) { + expert_ids[i] = 0; + } } template @@ -105,7 +127,12 @@ __global__ void moe_align_block_size_small_batch_expert_kernel( const scalar_t* __restrict__ topk_ids, int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts, - int32_t block_size, size_t numel) { + int32_t block_size, size_t numel, int32_t max_num_tokens_padded) { + // Initialize sorted_token_ids with numel + for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) { + sorted_token_ids[it] = numel; + } + const size_t tid = threadIdx.x; const size_t stride = blockDim.x; @@ -153,6 +180,13 @@ __global__ void moe_align_block_size_small_batch_expert_kernel( } } + // Fill remaining expert_ids with 0 + const size_t fill_start_idx = cumsum[num_experts] / block_size + threadIdx.x; + const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size); + for (size_t i = fill_start_idx; i < expert_ids_size; i += blockDim.x) { + expert_ids[i] = 0; + } + for (size_t i = tid; i < numel; i += stride) { int32_t expert_id = topk_ids[i]; int32_t rank_post_pad = @@ -179,13 +213,17 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int threads = 1024; threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + // BlockScan uses 1024 threads and assigns one thread per expert. + TORCH_CHECK(padded_num_experts < 1024, + "padded_num_experts must be less than 1024"); + VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { // calc needed amount of shared mem for `cumsum` tensors auto options_int = torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); torch::Tensor cumsum_buffer = - torch::zeros({num_experts + 1}, options_int); + torch::empty({num_experts + 1}, options_int); bool small_batch_expert_mode = (topk_ids.numel() < 1024) && (num_experts <= 64); @@ -203,7 +241,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, sorted_token_ids.data_ptr(), experts_ids.data_ptr(), num_tokens_post_pad.data_ptr(), num_experts, block_size, - topk_ids.numel()); + topk_ids.numel(), sorted_token_ids.size(0)); } else { auto align_kernel = vllm::moe::moe_align_block_size_kernel; @@ -217,7 +255,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, experts_ids.data_ptr(), num_tokens_post_pad.data_ptr(), num_experts, padded_num_experts, experts_per_warp, block_size, - topk_ids.numel(), cumsum_buffer.data_ptr()); + topk_ids.numel(), cumsum_buffer.data_ptr(), + sorted_token_ids.size(0)); const int block_threads = std::min(256, (int)threads); const int num_blocks = diff --git a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py index 3aae183dfa200..2c9ad509fa98e 100644 --- a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py +++ b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py @@ -111,6 +111,8 @@ def moe_align_block_size_triton( dtype=torch.int32, device=topk_ids.device) tokens_per_thread = cdiv(numel, num_experts) + sorted_token_ids.fill_(numel) + expert_ids.zero_() moe_align_block_size_stage1[grid]( topk_ids, @@ -205,11 +207,8 @@ def moe_align_block_size( sorted_ids = torch.empty((max_num_tokens_padded, ), dtype=torch.int32, device=topk_ids.device) - sorted_ids.fill_(topk_ids.numel()) max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) - # Expert ids must be zeroed out to prevent index out of bounds error while - # mapping global expert ids to local expert ids in expert parallelism. - expert_ids = torch.zeros((max_num_m_blocks, ), + expert_ids = torch.empty((max_num_m_blocks, ), dtype=torch.int32, device=topk_ids.device) num_tokens_post_pad = torch.empty((1),