[Kernel] optimize moe_align_block_size for cuda graph and large num_experts (e.g. DeepSeek-V3) (#12222)

Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <mgoin@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
Jinzhen Lin 2025-01-21 08:42:16 +08:00 committed by GitHub
parent 06a760d6e8
commit 750f4cabfa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 58 additions and 37 deletions

View File

@ -21,7 +21,7 @@ __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
} }
} // namespace } // namespace
template <typename scalar_t> template <typename scalar_t, typename token_cnts_t>
__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
int32_t* sorted_token_ids, int32_t* sorted_token_ids,
int32_t* expert_ids, int32_t* expert_ids,
@ -32,12 +32,8 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
const size_t start_idx = threadIdx.x * tokens_per_thread; const size_t start_idx = threadIdx.x * tokens_per_thread;
extern __shared__ int32_t shared_mem[]; extern __shared__ int32_t shared_mem[];
int32_t* cumsum = shared_mem; // 1d tensor with shape (num_experts + 1)
int32_t* tokens_cnts = token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + blockDim.x + 1);
shared_mem; // 2d tensor with shape (blockDim.x + 1, num_experts)
int32_t* cumsum =
shared_mem +
(blockDim.x + 1) * num_experts; // 1d tensor with shape (num_experts + 1)
for (int i = 0; i < num_experts; ++i) { for (int i = 0; i < num_experts; ++i) {
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
@ -74,7 +70,7 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
block_size) * block_size) *
block_size; block_size;
} }
*total_tokens_post_pad = cumsum[num_experts]; *total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]);
} }
__syncthreads(); __syncthreads();
@ -224,26 +220,44 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch::Tensor num_tokens_post_pad) { torch::Tensor num_tokens_post_pad) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// If we have very large number of experts, we can no longer use shared int device_max_shared_mem;
// memory. auto dev = topk_ids.get_device();
// TODO(simon): the right solution should be calculating the exact right cudaDeviceGetAttribute(&device_max_shared_mem,
// amount of shared memory and use that. The num_experts >= 256 is just a cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
// temporary solution to unblock Deepseek V3.
if (num_experts >= 256) { const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
const int32_t shared_mem_i32 =
((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
const int32_t shared_mem_i16 =
((num_thread + 1) * num_experts) * sizeof(uint16_t) +
(num_experts + 1) * sizeof(int32_t);
bool use_global_memory = false;
bool use_i16 = false; // Use uint16_t for shared memory token counts
if (shared_mem_i16 > device_max_shared_mem) {
use_global_memory = true;
} else if (shared_mem_i32 > device_max_shared_mem &&
topk_ids.numel() <= 65535) {
// when nelements of topk_ids is smaller than 65535 (max value of uint16),
// element value of token_cnts would also smaller than 65535,
// so we can use uint16 as dtype of token_cnts
use_i16 = true;
}
if (use_global_memory) {
VLLM_DISPATCH_INTEGRAL_TYPES( VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] { topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum` // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors // tensors
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
const int32_t mem_tokens_cnts = auto options_int = torch::TensorOptions()
((num_experts + 1) * num_experts) * sizeof(int32_t); .dtype(torch::kInt)
const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t); .device(topk_ids.device());
// allocate global memory torch::Tensor token_cnts_buffer =
int32_t* tokens_cnts; torch::empty({(num_experts + 1) * num_experts}, options_int);
int32_t* cumsum; torch::Tensor cumsum_buffer =
cudaMalloc(&tokens_cnts, mem_tokens_cnts); torch::empty({num_experts + 1}, options_int);
cudaMalloc(&cumsum, mem_cumsum);
auto kernel = auto kernel =
vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>; vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
@ -252,25 +266,32 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
sorted_token_ids.data_ptr<int32_t>(), sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(), experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size, num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.numel(), tokens_cnts, cumsum); topk_ids.numel(), token_cnts_buffer.data_ptr<int32_t>(),
cudaFree(tokens_cnts); cumsum_buffer.data_ptr<int32_t>());
cudaFree(cumsum); });
} else if (use_i16) {
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// set dynamic shared mem
auto kernel =
vllm::moe::moe_align_block_size_kernel<scalar_t, uint16_t>;
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
(void*)kernel, shared_mem_i16));
kernel<<<1, num_thread, shared_mem_i16, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.numel());
}); });
} else { } else {
VLLM_DISPATCH_INTEGRAL_TYPES( VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum` auto kernel =
// tensors vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
const int32_t shared_mem =
((num_thread + 1) * num_experts + (num_experts + 1)) *
sizeof(int32_t);
// set dynamic shared mem
auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
(void*)kernel, shared_mem)); (void*)kernel, shared_mem_i32));
kernel<<<1, num_thread, shared_mem, stream>>>( kernel<<<1, num_thread, shared_mem_i32, stream>>>(
topk_ids.data_ptr<scalar_t>(), topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(), sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(), experts_ids.data_ptr<int32_t>(),

View File

@ -607,7 +607,7 @@ class ModelConfig:
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
self.max_model_len) self.max_model_len)
MODEL_NOT_SUPPORT_CUDA_GRAPH = ['deepseek_v3', 'mllama'] MODEL_NOT_SUPPORT_CUDA_GRAPH = ['mllama']
if (self.hf_config.model_type in MODEL_NOT_SUPPORT_CUDA_GRAPH if (self.hf_config.model_type in MODEL_NOT_SUPPORT_CUDA_GRAPH
and not self.enforce_eager): and not self.enforce_eager):
logger.warning( logger.warning(