#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>

#include "../cuda_compat.h"
#include "../dispatch_utils.h"

#define CEILDIV(x, y) (((x) + (y) - 1) / (y))

namespace vllm {
namespace moe {

namespace {
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
                                         int32_t col) {
  // don't worry about overflow because num_experts is relatively small
  return row * total_col + col;
}
}  // namespace

template <typename scalar_t, typename token_cnts_t>
__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
                                            int32_t* sorted_token_ids,
                                            int32_t* expert_ids,
                                            int32_t* total_tokens_post_pad,
                                            int32_t num_experts,
                                            int32_t block_size, size_t numel) {
  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;

  extern __shared__ int32_t shared_mem[];
  int32_t* cumsum = shared_mem;  // 1d tensor with shape (num_experts + 1)
  token_cnts_t* tokens_cnts =
      (token_cnts_t*)(shared_mem + num_experts +
                      1);  // 2d tensor with shape (blockDim.x + 1, num_experts)

  for (int i = 0; i < num_experts; ++i) {
    tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
  }

  /**
   * In the first step we compute token_cnts[thread_index + 1][expert_index],
   * which counts how many tokens in the token shard of thread_index are
   * assigned to expert expert_index.
   */
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
  }

  __syncthreads();

  // For each expert we accumulate the token counts from the different threads.
  if (threadIdx.x < num_experts) {
    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
    for (int i = 1; i <= blockDim.x; ++i) {
      tokens_cnts[index(num_experts, i, threadIdx.x)] +=
          tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
    }
  }

  __syncthreads();

  // We accumulate the token counts of all experts in thread 0.
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      cumsum[i] = cumsum[i - 1] +
                  CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
                          block_size) *
                      block_size;
    }
    *total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]);
  }

  __syncthreads();

  /**
   * For each expert, each thread processes the tokens of the corresponding
   * blocks and stores the corresponding expert_id for each block.
   */
  if (threadIdx.x < num_experts) {
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
         i += block_size) {
      expert_ids[i / block_size] = threadIdx.x;
    }
  }

  /**
   * Each thread processes a token shard, calculating the index of each token
   * after sorting by expert number. Given the example topk_ids =
   * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
   * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
   * padding value (preset in Python).
   */
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int32_t expert_id = topk_ids[i];
    /** The cumsum[expert_id] stores the starting index of the tokens that the
     * expert with expert_id needs to process, and
     * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
     * processed by the expert with expert_id within the current thread's token
     * shard.
     */
    int32_t rank_post_pad =
        tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
        cumsum[expert_id];
    sorted_token_ids[rank_post_pad] = i;
    ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
  }
}
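/**
 * Worked example (added for clarity, derived from the kernel comments above):
 * for topk_ids = [0,1,2,1,2,3,0,3,4] and block_size = 4, the per-expert token
 * counts are [2, 2, 2, 2, 1]. Each count is rounded up to a multiple of
 * block_size, so cumsum = [0, 4, 8, 12, 16, 20], *total_tokens_post_pad = 20,
 * and expert_ids = [0, 1, 2, 3, 4] (one entry per block of 4 slots), which
 * matches the sorted_token_ids layout
 * [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *].
 */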
// TODO(simon): this is temporarily adapted from
// https://github.com/sgl-project/sglang/commit/31548116a8dc8c6df7e146e0587335a59fc5b9d7
// we did this to unblock Deepseek V3 but there should be a better
// implementation to manage shared memory.
template <typename scalar_t>
__global__ void moe_align_block_size_global_mem_kernel(
    scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
    int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
    int32_t block_size, size_t numel, int32_t* tokens_cnts, int32_t* cumsum) {
  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;

  for (int i = 0; i < num_experts; ++i) {
    tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
  }

  /**
   * In the first step we compute token_cnts[thread_index + 1][expert_index],
   * which counts how many tokens in the token shard of thread_index are
   * assigned to expert expert_index.
   */
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
  }

  __syncthreads();

  // For each expert we accumulate the token counts from the different threads.
  if (threadIdx.x < num_experts) {
    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
    for (int i = 1; i <= blockDim.x; ++i) {
      tokens_cnts[index(num_experts, i, threadIdx.x)] +=
          tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
    }
  }

  __syncthreads();

  // We accumulate the token counts of all experts in thread 0.
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      cumsum[i] = cumsum[i - 1] +
                  CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
                          block_size) *
                      block_size;
    }
    *total_tokens_post_pad = cumsum[num_experts];
  }

  __syncthreads();

  /**
   * For each expert, each thread processes the tokens of the corresponding
   * blocks and stores the corresponding expert_id for each block.
   */
  if (threadIdx.x < num_experts) {
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
         i += block_size) {
      expert_ids[i / block_size] = threadIdx.x;
    }
  }

  /**
   * Each thread processes a token shard, calculating the index of each token
   * after sorting by expert number. Given the example topk_ids =
   * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
   * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
   * padding value (preset in Python).
   */
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int32_t expert_id = topk_ids[i];
    /** The cumsum[expert_id] stores the starting index of the tokens that the
     * expert with expert_id needs to process, and
     * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
     * processed by the expert with expert_id within the current thread's token
     * shard.
     */
    int32_t rank_post_pad =
        tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
        cumsum[expert_id];
    sorted_token_ids[rank_post_pad] = i;
    ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
  }
}
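/**
 * Workspace note (added for clarity, not in the original source): this
 * global-memory variant expects caller-provided buffers in place of dynamic
 * shared memory. It indexes rows 0..blockDim.x of `tokens_cnts`, so the buffer
 * must hold at least (blockDim.x + 1) * num_experts int32 entries, and
 * `cumsum` must hold num_experts + 1 entries. The host wrapper below launches
 * it as a single block of num_thread threads.
 */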
// taken from
// https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957
template <typename scalar_t>
__global__ void sgl_moe_align_block_size_kernel(
    scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
    int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
    int32_t block_size, size_t numel, int32_t* cumsum) {
  __shared__ int32_t shared_counts[32][8];

  const int warp_id = threadIdx.x / 32;
  const int experts_per_warp = 8;
  const int my_expert_start = warp_id * experts_per_warp;

  // Initialize shared_counts for this warp's experts
  for (int i = 0; i < experts_per_warp; ++i) {
    if (my_expert_start + i < num_experts) {
      shared_counts[warp_id][i] = 0;
    }
  }

  __syncthreads();

  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;

  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int expert_id = topk_ids[i];
    int warp_idx = expert_id / experts_per_warp;
    int expert_offset = expert_id % experts_per_warp;
    atomicAdd(&shared_counts[warp_idx][expert_offset], 1);
  }

  __syncthreads();

  // Single thread computes cumulative sum and total tokens
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      int expert_count = 0;
      int warp_idx = (i - 1) / experts_per_warp;
      int expert_offset = (i - 1) % experts_per_warp;
      expert_count = shared_counts[warp_idx][expert_offset];

      cumsum[i] =
          cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
    }
    *total_tokens_post_pad = cumsum[num_experts];
  }

  __syncthreads();

  // Assign expert IDs to blocks
  if (threadIdx.x < num_experts) {
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
         i += block_size) {
      expert_ids[i / block_size] = threadIdx.x;
    }
  }
}

// taken from
// https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957
template <typename scalar_t>
__global__ void sgl_moe_token_sort_kernel(scalar_t* __restrict__ topk_ids,
                                          int32_t* sorted_token_ids,
                                          int32_t* cumsum_buffer,
                                          size_t numel) {
  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;

  for (size_t i = tid; i < numel; i += stride) {
    int32_t expert_id = topk_ids[i];
    int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
    sorted_token_ids[rank_post_pad] = i;
  }
}

template <typename scalar_t, int TOPK>
__global__ void moe_sum_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., topk, d]
    const int d) {
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    scalar_t x = 0.0;
#pragma unroll
    for (int k = 0; k < TOPK; ++k) {
      x += VLLM_LDG(&input[token_idx * TOPK * d + k * d + idx]);
    }
    out[token_idx * d + idx] = x;
  }
}

}  // namespace moe
}  // namespace vllm
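/*
 * Shared-memory sizing example (illustrative, added for clarity): the wrapper
 * below launches one block of num_thread = max(num_experts, WARP_SIZE)
 * threads. With num_experts = 256 and num_thread = 256,
 *   shared_mem_i32 = (257 * 256 + 257) * sizeof(int32_t) = 264196 bytes
 * (about 258 KiB), which typically exceeds the per-block opt-in shared-memory
 * limit, while
 *   shared_mem_i16 = 257 * 256 * sizeof(uint16_t) + 257 * sizeof(int32_t)
 *                  = 132612 bytes (about 130 KiB)
 * may still fit. This is why the wrapper chooses between the int32 path, the
 * uint16 path, and the global-memory fallback.
 */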
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                          int64_t block_size, torch::Tensor sorted_token_ids,
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad) {
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  int device_max_shared_mem;
  auto dev = topk_ids.get_device();
  cudaDeviceGetAttribute(&device_max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);

  const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
  const int32_t shared_mem_i32 =
      ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
  const int32_t shared_mem_i16 =
      ((num_thread + 1) * num_experts) * sizeof(uint16_t) +
      (num_experts + 1) * sizeof(int32_t);

  bool use_global_memory = false;
  bool use_i16 = false;  // Use uint16_t for shared memory token counts
  if (shared_mem_i32 < device_max_shared_mem) {
    // Do nothing in this case. We're all set to use int32_t token counts.
  } else if (shared_mem_i16 < device_max_shared_mem &&
             topk_ids.numel() <= 65535) {
    // When the number of elements of topk_ids is at most 65535 (the max value
    // of uint16), every token count is also at most 65535, so we can use
    // uint16 as the dtype of token_cnts.
    use_i16 = true;
  } else {
    use_global_memory = true;
  }

  if (use_global_memory) {
    VLLM_DISPATCH_INTEGRAL_TYPES(
        topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
          // calc needed amount of global mem for `tokens_cnts` and `cumsum`
          // tensors
          const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);

          auto options_int = torch::TensorOptions()
                                 .dtype(torch::kInt)
                                 .device(topk_ids.device());
          torch::Tensor token_cnts_buffer =
              torch::empty({(num_experts + 1) * num_experts}, options_int);
          torch::Tensor cumsum_buffer =
              torch::empty({num_experts + 1}, options_int);

          auto kernel =
              vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
          kernel<<<1, num_thread, 0, stream>>>(
              topk_ids.data_ptr<scalar_t>(),
              sorted_token_ids.data_ptr<int32_t>(),
              experts_ids.data_ptr<int32_t>(),
              num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
              topk_ids.numel(), token_cnts_buffer.data_ptr<int32_t>(),
              cumsum_buffer.data_ptr<int32_t>());
        });
  } else if (use_i16) {
    VLLM_DISPATCH_INTEGRAL_TYPES(
        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
          // set dynamic shared mem
          auto kernel =
              vllm::moe::moe_align_block_size_kernel<scalar_t, uint16_t>;
          AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
              (void*)kernel, shared_mem_i16));
          kernel<<<1, num_thread, shared_mem_i16, stream>>>(
              topk_ids.data_ptr<scalar_t>(),
              sorted_token_ids.data_ptr<int32_t>(),
              experts_ids.data_ptr<int32_t>(),
              num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
              topk_ids.numel());
        });
  } else {
    VLLM_DISPATCH_INTEGRAL_TYPES(
        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
          auto kernel =
              vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
          AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
              (void*)kernel, shared_mem_i32));
          kernel<<<1, num_thread, shared_mem_i32, stream>>>(
              topk_ids.data_ptr<scalar_t>(),
              sorted_token_ids.data_ptr<int32_t>(),
              experts_ids.data_ptr<int32_t>(),
              num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
              topk_ids.numel());
        });
  }
}
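/*
 * Output sizing note (derived from the kernels above, not part of the original
 * source): each expert's token count is padded up to a multiple of block_size,
 * so the padded total is at most
 *     topk_ids.numel() + num_experts * (block_size - 1).
 * The caller is expected to allocate sorted_token_ids with at least that many
 * entries, pre-filled with a padding value (see the kernel comments above),
 * and expert_ids with one entry per block of block_size slots.
 */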
void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                              int64_t block_size,
                              torch::Tensor sorted_token_ids,
                              torch::Tensor experts_ids,
                              torch::Tensor num_tokens_post_pad) {
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  TORCH_CHECK(num_experts == 256,
              "sgl_moe_align_block_size kernel only supports deepseek v3.");

  VLLM_DISPATCH_INTEGRAL_TYPES(
      topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
        // calc needed amount of global mem for the `cumsum` tensor
        auto options_int = torch::TensorOptions()
                               .dtype(torch::kInt)
                               .device(topk_ids.device());
        torch::Tensor cumsum_buffer =
            torch::zeros({num_experts + 1}, options_int);

        auto align_kernel =
            vllm::moe::sgl_moe_align_block_size_kernel<scalar_t>;
        align_kernel<<<1, 1024, 0, stream>>>(
            topk_ids.data_ptr<scalar_t>(),
            sorted_token_ids.data_ptr<int32_t>(),
            experts_ids.data_ptr<int32_t>(),
            num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
            topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());

        const int block_threads = 256;
        const int num_blocks =
            (topk_ids.numel() + block_threads - 1) / block_threads;
        const int max_blocks = 65535;
        const int actual_blocks = std::min(num_blocks, max_blocks);
        auto sort_kernel = vllm::moe::sgl_moe_token_sort_kernel<scalar_t>;
        sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
            topk_ids.data_ptr<scalar_t>(),
            sorted_token_ids.data_ptr<int32_t>(),
            cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
      });
}

void moe_sum(torch::Tensor& input,   // [num_tokens, topk, hidden_size]
             torch::Tensor& output)  // [num_tokens, hidden_size]
{
  const int hidden_size = input.size(-1);
  const int num_tokens = output.numel() / hidden_size;
  const int topk = input.size(1);

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  switch (topk) {
    case 2:
      VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
        vllm::moe::moe_sum_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
            output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
            hidden_size);
      });
      break;

    case 3:
      VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
        vllm::moe::moe_sum_kernel<scalar_t, 3><<<grid, block, 0, stream>>>(
            output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
            hidden_size);
      });
      break;

    case 4:
      VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
        vllm::moe::moe_sum_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
            output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
            hidden_size);
      });
      break;

    default:
      at::sum_out(output, input, 1);
      break;
  }
}
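/*
 * Usage sketch (illustrative only, hypothetical caller not in the original
 * file): moe_sum reduces over the topk dimension, i.e.
 * output[t, d] = sum_k input[t, k, d]. For topk in {2, 3, 4} a templated
 * kernel is launched; any other topk falls back to at::sum_out(output, input,
 * 1). A minimal C++ caller might look like:
 *
 *   torch::Tensor input = torch::randn(
 *       {num_tokens, topk, hidden_size},
 *       torch::dtype(torch::kFloat16).device(torch::kCUDA));
 *   torch::Tensor output =
 *       torch::empty({num_tokens, hidden_size}, input.options());
 *   moe_sum(input, output);
 */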