[perf] Speed up align sum kernels (#21079)

Signed-off-by: Himanshu Jaju <hj@mistral.ai>
This commit is contained in:
Himanshu Jaju 2025-07-21 19:19:23 +01:00 committed by GitHub
parent 005ae9be6c
commit 0ec82edda5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 60 additions and 25 deletions

View File

@ -33,15 +33,13 @@ def check_correctness(num_tokens, num_experts=256, block_size=256, topk=8):
sorted_ids_triton = torch.empty(
(max_num_tokens_padded,), dtype=torch.int32, device="cuda"
)
sorted_ids_triton.fill_(topk_ids.numel()) # fill with sentinel value
expert_ids_triton = torch.zeros(
expert_ids_triton = torch.empty(
(max_num_tokens_padded // block_size,), dtype=torch.int32, device="cuda"
)
num_tokens_post_pad_triton = torch.empty((1,), dtype=torch.int32, device="cuda")
sorted_ids_vllm = torch.empty_like(sorted_ids_triton)
sorted_ids_vllm.fill_(topk_ids.numel())
expert_ids_vllm = torch.zeros_like(expert_ids_triton)
expert_ids_vllm = torch.empty_like(expert_ids_triton)
num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_triton)
# 2. run implementations
@ -102,7 +100,6 @@ def benchmark(num_tokens, num_experts, topk, provider):
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda")
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = max_num_tokens_padded // block_size
expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda")
num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")

View File

@ -1,6 +1,7 @@
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cub/cub.cuh>
#include <ATen/ATen.h>
#include <ATen/cuda/Atomic.cuh>
@ -19,9 +20,14 @@ __global__ void moe_align_block_size_kernel(
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size,
size_t numel, int32_t* __restrict__ cumsum) {
size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded) {
extern __shared__ int32_t shared_counts[];
// Initialize sorted_token_ids with numel
for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
sorted_token_ids[it] = numel;
}
const int warp_id = threadIdx.x / WARP_SIZE;
const int my_expert_start = warp_id * experts_per_warp;
@ -45,18 +51,27 @@ __global__ void moe_align_block_size_kernel(
__syncthreads();
if (threadIdx.x == 0) {
cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) {
int expert_count = 0;
int warp_idx = (i - 1) / experts_per_warp;
int expert_offset = (i - 1) % experts_per_warp;
expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset];
// Compute prefix sum over token counts per expert
using BlockScan = cub::BlockScan<int32_t, 1024>;
__shared__ typename BlockScan::TempStorage temp_storage;
cumsum[i] =
cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
}
*total_tokens_post_pad = cumsum[num_experts];
int expert_count = 0;
int expert_id = threadIdx.x;
if (expert_id < num_experts) {
int warp_idx = expert_id / experts_per_warp;
int expert_offset = expert_id % experts_per_warp;
expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset];
expert_count = CEILDIV(expert_count, block_size) * block_size;
}
int cumsum_val;
BlockScan(temp_storage).ExclusiveSum(expert_count, cumsum_val);
if (expert_id <= num_experts) {
cumsum[expert_id] = cumsum_val;
}
if (expert_id == num_experts) {
*total_tokens_post_pad = cumsum_val;
}
__syncthreads();
@ -67,6 +82,13 @@ __global__ void moe_align_block_size_kernel(
expert_ids[i / block_size] = threadIdx.x;
}
}
// Fill remaining expert_ids with 0
const size_t fill_start_idx = cumsum[num_experts] / block_size + threadIdx.x;
const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size);
for (size_t i = fill_start_idx; i < expert_ids_size; i += blockDim.x) {
expert_ids[i] = 0;
}
}
template <typename scalar_t>
@ -105,7 +127,12 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
int32_t block_size, size_t numel) {
int32_t block_size, size_t numel, int32_t max_num_tokens_padded) {
// Initialize sorted_token_ids with numel
for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
sorted_token_ids[it] = numel;
}
const size_t tid = threadIdx.x;
const size_t stride = blockDim.x;
@ -153,6 +180,13 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
}
}
// Fill remaining expert_ids with 0
const size_t fill_start_idx = cumsum[num_experts] / block_size + threadIdx.x;
const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size);
for (size_t i = fill_start_idx; i < expert_ids_size; i += blockDim.x) {
expert_ids[i] = 0;
}
for (size_t i = tid; i < numel; i += stride) {
int32_t expert_id = topk_ids[i];
int32_t rank_post_pad =
@ -179,13 +213,17 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int threads = 1024;
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
// BlockScan uses 1024 threads and assigns one thread per expert.
TORCH_CHECK(padded_num_experts < 1024,
"padded_num_experts must be less than 1024");
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `cumsum` tensors
auto options_int =
torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
torch::Tensor cumsum_buffer =
torch::zeros({num_experts + 1}, options_int);
torch::empty({num_experts + 1}, options_int);
bool small_batch_expert_mode =
(topk_ids.numel() < 1024) && (num_experts <= 64);
@ -203,7 +241,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.numel());
topk_ids.numel(), sorted_token_ids.size(0));
} else {
auto align_kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
@ -217,7 +255,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_experts,
padded_num_experts, experts_per_warp, block_size,
topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>(),
sorted_token_ids.size(0));
const int block_threads = std::min(256, (int)threads);
const int num_blocks =

View File

@ -111,6 +111,8 @@ def moe_align_block_size_triton(
dtype=torch.int32,
device=topk_ids.device)
tokens_per_thread = cdiv(numel, num_experts)
sorted_token_ids.fill_(numel)
expert_ids.zero_()
moe_align_block_size_stage1[grid](
topk_ids,
@ -205,11 +207,8 @@ def moe_align_block_size(
sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
# Expert ids must be zeroed out to prevent index out of bounds error while
# mapping global expert ids to local expert ids in expert parallelism.
expert_ids = torch.zeros((max_num_m_blocks, ),
expert_ids = torch.empty((max_num_m_blocks, ),
dtype=torch.int32,
device=topk_ids.device)
num_tokens_post_pad = torch.empty((1),