Lora MoE Align Improvements (#29257)

Signed-off-by: gnovack <gnovack@amazon.com>
This commit is contained in:
gnovack 2025-12-08 18:35:16 -08:00 committed by GitHub
parent db14f61f2d
commit ea657f2078
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 360 additions and 249 deletions

View File

@ -944,7 +944,6 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
set(VLLM_MOE_EXT_SRC set(VLLM_MOE_EXT_SRC
"csrc/moe/torch_bindings.cpp" "csrc/moe/torch_bindings.cpp"
"csrc/moe/moe_align_sum_kernels.cu" "csrc/moe/moe_align_sum_kernels.cu"
"csrc/moe/moe_lora_align_sum_kernels.cu"
"csrc/moe/topk_softmax_kernels.cu") "csrc/moe/topk_softmax_kernels.cu")
if(VLLM_GPU_LANG STREQUAL "CUDA") if(VLLM_GPU_LANG STREQUAL "CUDA")

View File

@ -14,7 +14,6 @@
namespace vllm { namespace vllm {
namespace moe { namespace moe {
namespace batched_moe_align_block_size { namespace batched_moe_align_block_size {
// Note num_threads needs to be 1024 for BlockScan Reduction in the kernel. // Note num_threads needs to be 1024 for BlockScan Reduction in the kernel.
@ -80,23 +79,30 @@ __global__ void batched_moe_align_block_size_kernel(
} // namespace batched_moe_align_block_size } // namespace batched_moe_align_block_size
template <typename scalar_t> template <typename scalar_t>
__global__ void moe_align_block_size_kernel( __device__ void _moe_align_block_size(
const scalar_t* __restrict__ topk_ids, const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
int32_t* __restrict__ total_tokens_post_pad, int32_t* __restrict__ total_tokens_post_pad,
int32_t* __restrict__ expert_map, int32_t num_experts, int32_t* __restrict__ expert_map, int32_t num_experts,
int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size, int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size,
size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded, size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded,
bool has_expert_map) { int32_t max_num_m_blocks, int32_t model_offset, int32_t inactive_expert_id,
int32_t topk_num, int32_t* token_mask, bool has_expert_map) {
extern __shared__ int32_t shared_counts[]; extern __shared__ int32_t shared_counts[];
// Use a separate threadblock to fill sorted_token_ids. // Compute input buffer offsets. Typically these will all be 0, except when
// using Multi LoRA.
int sorted_token_ids_offset = max_num_tokens_padded * model_offset;
int expert_ids_offset = max_num_m_blocks * model_offset;
int cumsum_offset = (num_experts + 1) * model_offset;
// Use separate threadblocks to fill sorted_token_ids.
// This is safe since the current kernel does not use sorted_token_ids. // This is safe since the current kernel does not use sorted_token_ids.
if (blockIdx.x == 1) { if (blockIdx.x % 2) {
// Initialize sorted_token_ids with numel // Initialize sorted_token_ids with numel
for (size_t it = threadIdx.x; it < max_num_tokens_padded; for (size_t it = threadIdx.x; it < max_num_tokens_padded;
it += blockDim.x) { it += blockDim.x) {
sorted_token_ids[it] = numel; sorted_token_ids[sorted_token_ids_offset + it] = numel;
} }
return; return;
} }
@ -127,7 +133,9 @@ __global__ void moe_align_block_size_kernel(
} }
int warp_idx = expert_id / experts_per_warp; int warp_idx = expert_id / experts_per_warp;
int expert_offset = expert_id % experts_per_warp; int expert_offset = expert_id % experts_per_warp;
atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], 1); int mask = token_mask == nullptr ? 1 : token_mask[i / topk_num];
atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset],
mask);
} }
__syncthreads(); __syncthreads();
@ -148,77 +156,44 @@ __global__ void moe_align_block_size_kernel(
int cumsum_val; int cumsum_val;
BlockScan(temp_storage).ExclusiveSum(expert_count, cumsum_val); BlockScan(temp_storage).ExclusiveSum(expert_count, cumsum_val);
if (expert_id <= num_experts) { if (expert_id <= num_experts) {
cumsum[expert_id] = cumsum_val; cumsum[cumsum_offset + expert_id] = cumsum_val;
} }
if (expert_id == num_experts) { if (expert_id == num_experts) {
*total_tokens_post_pad = cumsum_val; total_tokens_post_pad[model_offset] = cumsum_val;
} }
__syncthreads(); __syncthreads();
if (threadIdx.x < num_experts) { if (threadIdx.x < num_experts) {
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; for (int i = cumsum[cumsum_offset + threadIdx.x];
i += block_size) { i < cumsum[cumsum_offset + threadIdx.x + 1]; i += block_size) {
expert_ids[i / block_size] = threadIdx.x; expert_ids[expert_ids_offset + i / block_size] = threadIdx.x;
} }
} }
// Fill remaining expert_ids with 0 // Fill remaining expert_ids with 0
const size_t fill_start_idx = cumsum[num_experts] / block_size + threadIdx.x; const size_t fill_start_idx =
const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size); cumsum[cumsum_offset + num_experts] / block_size + threadIdx.x;
for (size_t i = fill_start_idx; i < expert_ids_size; i += blockDim.x) { for (size_t i = fill_start_idx; i < max_num_m_blocks; i += blockDim.x) {
expert_ids[i] = 0; expert_ids[expert_ids_offset + i] = inactive_expert_id;
}
}
template <typename scalar_t>
__global__ void count_and_sort_expert_tokens_kernel(
const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer,
int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts,
bool has_expert_map) {
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (size_t i = tid; i < numel; i += stride) {
int32_t expert_id = topk_ids[i];
if (expert_id >= num_experts) {
continue;
}
if (has_expert_map) {
expert_id = expert_map[expert_id];
// filter invalid experts
if (expert_id == -1) continue;
}
int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
sorted_token_ids[rank_post_pad] = i;
}
}
template <typename scalar_t, int TOPK>
__global__ void moe_sum_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., topk, d]
const int d) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
scalar_t x = 0.0;
#pragma unroll
for (int k = 0; k < TOPK; ++k) {
x += VLLM_LDG(&input[token_idx * TOPK * d + k * d + idx]);
}
out[token_idx * d + idx] = x;
} }
} }
template <typename scalar_t, int32_t fill_threads> template <typename scalar_t, int32_t fill_threads>
__global__ void moe_align_block_size_small_batch_expert_kernel( __device__ void _moe_align_block_size_small_batch_expert(
const scalar_t* __restrict__ topk_ids, const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
int32_t* __restrict__ total_tokens_post_pad, int32_t* __restrict__ total_tokens_post_pad,
int32_t* __restrict__ expert_map, int32_t num_experts, int32_t block_size, int32_t* __restrict__ expert_map, int32_t num_experts, int32_t block_size,
size_t numel, int32_t max_num_tokens_padded, bool has_expert_map) { size_t numel, int32_t max_num_tokens_padded, int32_t max_num_m_blocks,
int32_t inactive_expert_id, int32_t model_offset, int32_t topk_num,
int32_t* token_mask, bool has_expert_map) {
// Compute input buffer offsets. Typically these will all be 0, except when
// using Multi LoRA.
int sorted_token_ids_offset = max_num_tokens_padded * model_offset;
int expert_ids_offset = max_num_m_blocks * model_offset;
// Use an additional group of threads to fill sorted_token_ids. // Use an additional group of threads to fill sorted_token_ids.
// Since the current kernel will use sorted_token_ids afterward, // Since the current kernel will use sorted_token_ids afterward,
// we fill sorted_token_ids within the same threadblock to make // we fill sorted_token_ids within the same threadblock to make
@ -227,7 +202,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
// Initialize sorted_token_ids with numel // Initialize sorted_token_ids with numel
for (size_t it = threadIdx.x; it < max_num_tokens_padded; for (size_t it = threadIdx.x; it < max_num_tokens_padded;
it += fill_threads) { it += fill_threads) {
sorted_token_ids[it] = numel; sorted_token_ids[sorted_token_ids_offset + it] = numel;
} }
// Three __syncthreads() corresponding to the other threads // Three __syncthreads() corresponding to the other threads
__syncthreads(); __syncthreads();
@ -254,7 +229,8 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
// filter invalid expert // filter invalid expert
if (expert_id == -1) continue; if (expert_id == -1) continue;
} }
++tokens_cnts[(tid + 1) * num_experts + expert_id]; int mask = token_mask == nullptr ? 1 : token_mask[i / topk_num];
tokens_cnts[(tid + 1) * num_experts + expert_id] += mask;
} }
__syncthreads(); __syncthreads();
@ -277,22 +253,22 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
CEILDIV(tokens_cnts[stride * num_experts + i - 1], block_size) * CEILDIV(tokens_cnts[stride * num_experts + i - 1], block_size) *
block_size; block_size;
} }
*total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]); total_tokens_post_pad[model_offset] =
static_cast<int32_t>(cumsum[num_experts]);
} }
__syncthreads(); __syncthreads();
if (tid < num_experts) { if (tid < num_experts) {
for (int i = cumsum[tid]; i < cumsum[tid + 1]; i += block_size) { for (int i = cumsum[tid]; i < cumsum[tid + 1]; i += block_size) {
expert_ids[i / block_size] = tid; expert_ids[expert_ids_offset + i / block_size] = tid;
} }
} }
// Fill remaining expert_ids with 0 // Fill remaining expert_ids with 0
const size_t fill_start_idx = cumsum[num_experts] / block_size + tid; const size_t fill_start_idx = cumsum[num_experts] / block_size + tid;
const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size); for (size_t i = fill_start_idx; i < max_num_m_blocks; i += stride) {
for (size_t i = fill_start_idx; i < expert_ids_size; i += stride) { expert_ids[expert_ids_offset + i] = inactive_expert_id;
expert_ids[i] = 0;
} }
for (size_t i = tid; i < numel; i += stride) { for (size_t i = tid; i < numel; i += stride) {
@ -304,11 +280,195 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
} }
int32_t rank_post_pad = int32_t rank_post_pad =
tokens_cnts[tid * num_experts + expert_id] + cumsum[expert_id]; tokens_cnts[tid * num_experts + expert_id] + cumsum[expert_id];
sorted_token_ids[rank_post_pad] = i;
++tokens_cnts[tid * num_experts + expert_id]; if (token_mask == nullptr || token_mask[i / topk_num]) {
sorted_token_ids[sorted_token_ids_offset + rank_post_pad] = i;
++tokens_cnts[tid * num_experts + expert_id];
}
} }
} }
template <typename scalar_t>
__device__ void _count_and_sort_expert_tokens(
const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer,
int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts,
int32_t max_num_tokens_padded, int32_t* __restrict__ token_mask,
int32_t model_offset, int32_t topk_num, bool has_expert_map) {
const size_t tid = blockIdx.y * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.y;
for (size_t i = tid; i < numel; i += stride) {
int32_t expert_id = topk_ids[i];
if (expert_id >= num_experts) {
continue;
}
if (has_expert_map) {
expert_id = expert_map[expert_id];
// filter invalid experts
if (expert_id == -1) continue;
}
if (token_mask == nullptr || token_mask[i / topk_num]) {
int32_t rank_post_pad = atomicAdd(
&cumsum_buffer[(model_offset * (num_experts + 1)) + expert_id], 1);
sorted_token_ids[max_num_tokens_padded * model_offset + rank_post_pad] =
i;
}
}
}
template <typename scalar_t>
__global__ void moe_align_block_size_kernel(
const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
int32_t* __restrict__ total_tokens_post_pad,
int32_t* __restrict__ expert_map, int32_t num_experts,
int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size,
size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded,
int32_t topk_num, bool has_expert_map) {
_moe_align_block_size(
topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map,
num_experts, padded_num_experts, experts_per_warp, block_size, numel,
cumsum, max_num_tokens_padded, CEILDIV(max_num_tokens_padded, block_size),
0, 0, topk_num, nullptr, has_expert_map);
}
template <typename scalar_t>
__global__ void count_and_sort_expert_tokens_kernel(
const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer,
int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts,
int32_t max_num_tokens_padded, int32_t topk_num, bool has_expert_map) {
_count_and_sort_expert_tokens(
topk_ids, sorted_token_ids, cumsum_buffer, expert_map, numel, num_experts,
max_num_tokens_padded, nullptr, 0, topk_num, has_expert_map);
}
template <typename scalar_t, int TOPK>
__global__ void moe_sum_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., topk, d]
const int d) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
scalar_t x = 0.0;
#pragma unroll
for (int k = 0; k < TOPK; ++k) {
x += VLLM_LDG(&input[token_idx * TOPK * d + k * d + idx]);
}
out[token_idx * d + idx] = x;
}
}
template <typename scalar_t, int32_t fill_threads>
__global__ void moe_align_block_size_small_batch_expert_kernel(
const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
int32_t* __restrict__ total_tokens_post_pad,
int32_t* __restrict__ expert_map, int32_t num_experts, int32_t block_size,
size_t numel, int32_t max_num_tokens_padded, int32_t topk_num,
bool has_expert_map) {
_moe_align_block_size_small_batch_expert<scalar_t, fill_threads>(
topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map,
num_experts, block_size, numel, max_num_tokens_padded,
CEILDIV(max_num_tokens_padded, block_size), 0, 0, topk_num, nullptr,
has_expert_map);
}
template <typename scalar_t>
__global__ void moe_lora_align_block_size_kernel(
scalar_t* __restrict__ topk_ids, int32_t* __restrict__ token_lora_mapping,
int64_t block_size, int32_t* __restrict__ expert_map, int num_experts,
int max_loras, size_t numel, int max_num_tokens_padded,
int max_num_m_blocks, int32_t* __restrict__ sorted_token_ids,
int32_t* __restrict__ expert_ids, int32_t topk_num,
int32_t* total_tokens_post_pad, int32_t* adapter_enabled,
int32_t* __restrict__ cumsum, int32_t experts_per_warp,
int32_t padded_num_experts, int32_t* lora_ids,
int32_t* __restrict__ token_mask, bool has_expert_map) {
int lora_idx = blockIdx.x / 2;
int lora_id = lora_ids[lora_idx];
if (lora_id == -1 || adapter_enabled[lora_id] == 0) {
return;
}
// Populate the token_mask based on the token-LoRA mapping
int num_tokens = numel / topk_num;
if (threadIdx.x == 0) {
total_tokens_post_pad[lora_id] = 0;
for (int i = 0; i < num_tokens; i++) {
token_mask[(lora_id * num_tokens) + i] =
(int)token_lora_mapping[i] == lora_id;
}
}
__syncthreads();
_moe_align_block_size(
topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map,
num_experts, padded_num_experts, experts_per_warp, block_size, numel,
cumsum, max_num_tokens_padded, max_num_m_blocks, lora_id, -1, topk_num,
&token_mask[(lora_id * num_tokens)], has_expert_map);
}
template <typename scalar_t>
__global__ void lora_count_and_sort_expert_tokens_kernel(
const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer,
int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts,
int32_t max_num_tokens_padded, int32_t topk_num, int32_t* token_mask,
int32_t* lora_ids, bool has_expert_map) {
int lora_idx = blockIdx.x;
int lora_id = lora_ids[lora_idx];
if (lora_id == -1) {
return;
}
int num_tokens = numel / topk_num;
_count_and_sort_expert_tokens(
topk_ids, sorted_token_ids, cumsum_buffer, expert_map, numel, num_experts,
max_num_tokens_padded, &token_mask[(lora_id * num_tokens)], lora_id,
topk_num, has_expert_map);
}
template <typename scalar_t, int32_t fill_threads>
__global__ void moe_lora_align_block_size_small_batch_expert_kernel(
scalar_t* __restrict__ topk_ids, int32_t* token_lora_mapping,
int64_t block_size, int32_t* __restrict__ expert_map, int num_experts,
int max_loras, size_t numel, int max_num_tokens_padded,
int max_num_m_blocks, int32_t* __restrict__ sorted_token_ids,
int32_t* __restrict__ expert_ids, int topk_num,
int32_t* total_tokens_post_pad, int32_t* adapter_enabled, int32_t* lora_ids,
int32_t* token_mask, bool has_expert_map) {
int lora_idx = blockIdx.x;
int lora_id = lora_ids[lora_idx];
if (lora_id == -1 || adapter_enabled[lora_id] == 0) {
return;
}
int num_tokens = numel / topk_num;
if (threadIdx.x == 0) {
total_tokens_post_pad[lora_id] = 0;
for (int i = 0; i < num_tokens; i++) {
token_mask[(lora_id * num_tokens) + i] =
(int)token_lora_mapping[i] == lora_id;
}
}
__syncthreads();
_moe_align_block_size_small_batch_expert<scalar_t, fill_threads>(
topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map,
num_experts, block_size, numel, max_num_tokens_padded, max_num_m_blocks,
-1, lora_id, topk_num, &token_mask[(lora_id * num_tokens)],
has_expert_map);
}
} // namespace moe } // namespace moe
} // namespace vllm } // namespace vllm
@ -365,7 +525,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
experts_ids.data_ptr<int32_t>(), experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
expert_map.data_ptr<int32_t>(), num_experts, block_size, expert_map.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.numel(), sorted_token_ids.size(0), has_expert_map); topk_ids.numel(), sorted_token_ids.size(0), topk_ids.size(1),
has_expert_map);
} else { } else {
torch::Tensor cumsum_buffer = torch::Tensor cumsum_buffer =
torch::empty({num_experts + 1}, options_int); torch::empty({num_experts + 1}, options_int);
@ -386,21 +547,23 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
expert_map.data_ptr<int32_t>(), num_experts, padded_num_experts, expert_map.data_ptr<int32_t>(), num_experts, padded_num_experts,
experts_per_warp, block_size, topk_ids.numel(), experts_per_warp, block_size, topk_ids.numel(),
cumsum_buffer.data_ptr<int32_t>(), sorted_token_ids.size(0), cumsum_buffer.data_ptr<int32_t>(), sorted_token_ids.size(0),
has_expert_map); topk_ids.size(1), has_expert_map);
const int block_threads = std::min(256, (int)threads); const int block_threads = std::min(256, (int)threads);
const int num_blocks = const int num_blocks =
(topk_ids.numel() + block_threads - 1) / block_threads; (topk_ids.numel() + block_threads - 1) / block_threads;
const int max_blocks = 65535; const int max_blocks = 65535;
const int actual_blocks = std::min(num_blocks, max_blocks); const int actual_blocks = std::min(num_blocks, max_blocks);
dim3 gridDims(1, actual_blocks);
auto sort_kernel = auto sort_kernel =
vllm::moe::count_and_sort_expert_tokens_kernel<scalar_t>; vllm::moe::count_and_sort_expert_tokens_kernel<scalar_t>;
sort_kernel<<<actual_blocks, block_threads, 0, stream>>>( sort_kernel<<<gridDims, block_threads, 0, stream>>>(
topk_ids.data_ptr<scalar_t>(), topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(), sorted_token_ids.data_ptr<int32_t>(),
cumsum_buffer.data_ptr<int32_t>(), expert_map.data_ptr<int32_t>(), cumsum_buffer.data_ptr<int32_t>(), expert_map.data_ptr<int32_t>(),
topk_ids.numel(), num_experts, has_expert_map); topk_ids.numel(), num_experts, sorted_token_ids.size(0),
topk_ids.size(1), has_expert_map);
} }
}); });
} }
@ -474,3 +637,123 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
break; break;
} }
} }
void moe_lora_align_block_size(
torch::Tensor topk_ids, torch::Tensor token_lora_mapping,
int64_t num_experts, int64_t block_size, int64_t max_loras,
int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
torch::Tensor lora_ids, std::optional<torch::Tensor> maybe_expert_map) {
const int topk_num = topk_ids.size(1);
TORCH_CHECK(block_size > 0, "block_size should be greater than 0. ");
int device_max_shared_mem;
auto dev = topk_ids.get_device();
cudaDeviceGetAttribute(&device_max_shared_mem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int64_t padded_num_experts =
((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
// BlockScan uses 1024 threads and assigns one thread per expert.
TORCH_CHECK(padded_num_experts < 1024,
"padded_num_experts must be less than 1024");
auto options_int =
torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
torch::Tensor token_mask =
torch::empty({max_loras * topk_ids.size(0)}, options_int);
bool has_expert_map = maybe_expert_map.has_value();
torch::Tensor expert_map;
if (has_expert_map) {
expert_map = maybe_expert_map.value();
} else {
expert_map = torch::empty({0}, options_int);
}
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] {
bool small_batch_expert_mode =
(topk_ids.numel() < 1024) && (num_experts <= 64);
if (small_batch_expert_mode) {
const int32_t num_thread = max((int32_t)num_experts, 128);
const int32_t shared_mem =
(num_thread + 1) * num_experts * sizeof(int32_t) +
(num_experts + 1) * sizeof(int32_t);
if (shared_mem > device_max_shared_mem) {
TORCH_CHECK(false, "Shared memory usage exceeds device limit.");
}
// threadIdx.x >= fill_threads: counting experts and aligning
// threadIdx.x < fill_threads: filling sorted_token_ids
constexpr int32_t fill_threads = 256;
dim3 blockDim(num_thread + fill_threads);
auto kernel =
vllm::moe::moe_lora_align_block_size_small_batch_expert_kernel<
scalar_t, fill_threads>;
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
(void*)kernel, shared_mem));
kernel<<<max_loras, blockDim, shared_mem, stream>>>(
topk_ids.data_ptr<scalar_t>(),
token_lora_mapping.data_ptr<int32_t>(), block_size,
expert_map.data_ptr<int32_t>(), num_experts, max_loras,
topk_ids.numel(), max_num_tokens_padded, max_num_m_blocks,
sorted_token_ids.data_ptr<int32_t>(),
expert_ids.data_ptr<int32_t>(), topk_num,
num_tokens_post_pad.data_ptr<int32_t>(),
adapter_enabled.data_ptr<int32_t>(), lora_ids.data_ptr<int32_t>(),
token_mask.data_ptr<int32_t>(), has_expert_map);
} else {
int num_thread = 1024;
dim3 blockDim(num_thread);
size_t num_warps = CEILDIV(padded_num_experts, WARP_SIZE);
size_t shared_mem_size = num_warps * WARP_SIZE * sizeof(int32_t);
// cumsum buffer
torch::Tensor cumsum =
torch::zeros({max_loras * (num_experts + 1)}, options_int);
auto align_kernel =
vllm::moe::moe_lora_align_block_size_kernel<scalar_t>;
// launch two threadblocks for each lora
// blockIdx.x % 2 == 0: counting experts and aligning
// blockIdx.x % 2 == 1: filling sorted_token_ids
align_kernel<<<max_loras * 2, blockDim, shared_mem_size, stream>>>(
topk_ids.data_ptr<scalar_t>(),
token_lora_mapping.data_ptr<int32_t>(), block_size,
expert_map.data_ptr<int32_t>(), num_experts, max_loras,
topk_ids.numel(), max_num_tokens_padded, max_num_m_blocks,
sorted_token_ids.data_ptr<int32_t>(),
expert_ids.data_ptr<int32_t>(), topk_num,
num_tokens_post_pad.data_ptr<int32_t>(),
adapter_enabled.data_ptr<int32_t>(), cumsum.data_ptr<int32_t>(),
WARP_SIZE, padded_num_experts, lora_ids.data_ptr<int32_t>(),
token_mask.data_ptr<int32_t>(), has_expert_map);
const int block_threads = std::min(256, (int)num_thread);
const int num_blocks =
(topk_ids.numel() + block_threads - 1) / block_threads;
const int max_blocks = 65535;
const int actual_blocks = std::min(num_blocks, max_blocks);
dim3 gridDims(max_loras, actual_blocks);
auto sort_kernel =
vllm::moe::lora_count_and_sort_expert_tokens_kernel<scalar_t>;
sort_kernel<<<gridDims, block_threads, 0, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(), cumsum.data_ptr<int32_t>(),
expert_map.data_ptr<int32_t>(), topk_ids.numel(), num_experts,
max_num_tokens_padded, topk_num, token_mask.data_ptr<int32_t>(),
lora_ids.data_ptr<int32_t>(), has_expert_map);
}
});
}

View File

@ -1,174 +0,0 @@
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/ATen.h>
#include <ATen/cuda/Atomic.cuh>
#include "../cuda_compat.h"
#include "../dispatch_utils.h"
#include "core/math.hpp"
namespace {
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
int32_t col) {
return row * total_col + col;
}
} // namespace
// TODO: Refactor common parts with moe_align_sum_kernels
template <typename scalar_t, typename token_cnts_t>
__global__ void moe_lora_align_sum_kernel(
scalar_t* __restrict__ topk_ids, int32_t* token_lora_mapping,
int64_t block_size, int num_experts, int max_loras, size_t numel,
int max_num_tokens_padded, int max_num_m_blocks,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
int topk_num, int32_t* total_tokens_post_pad, int32_t* adapter_enabled,
int32_t* lora_ids) {
const size_t tokens_per_thread = div_ceil(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread;
int lora_idx = blockIdx.x;
int lora_id = lora_ids[lora_idx];
if (lora_id == -1 || adapter_enabled[lora_id] == 0) {
return;
}
extern __shared__ int32_t shared_mem[];
int32_t* cumsum = shared_mem;
token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + num_experts + 1);
// Initialize sorted_token_ids with numel
for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
sorted_token_ids[lora_id * max_num_tokens_padded + it] = numel;
}
// Initialize expert_ids with -1
for (size_t it = threadIdx.x; it < max_num_m_blocks; it += blockDim.x) {
expert_ids[lora_id * max_num_m_blocks + it] = -1;
}
// Initialize total_tokens_post_pad with 0
if (threadIdx.x == 0) {
total_tokens_post_pad[lora_id] = 0;
}
for (int i = 0; i < num_experts; ++i) {
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
}
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int mask = token_lora_mapping[i / topk_num] == lora_id;
int idx = index(num_experts, threadIdx.x + 1, topk_ids[i]);
tokens_cnts[idx] += mask;
}
__syncthreads();
// For each expert we accumulate the token counts from the different threads.
if (threadIdx.x < num_experts) {
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
for (int i = 1; i <= blockDim.x; ++i) {
tokens_cnts[index(num_experts, i, threadIdx.x)] +=
tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
}
}
__syncthreads();
// We accumulate the token counts of all experts in thread 0.
if (threadIdx.x == 0) {
cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) {
cumsum[i] = cumsum[i - 1] +
div_ceil(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
block_size) *
block_size;
}
total_tokens_post_pad[lora_id] = static_cast<int32_t>(cumsum[num_experts]);
}
__syncthreads();
/**
* For each expert, each thread processes the tokens of the corresponding
* blocks and stores the corresponding expert_id for each block.
*/
if (threadIdx.x < num_experts) {
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
i += block_size) {
expert_ids[index(max_num_m_blocks, lora_id, i / block_size)] =
threadIdx.x;
}
}
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int32_t expert_id = topk_ids[i];
/** The cumsum[expert_id] stores the starting index of the tokens that the
* expert with expert_id needs to process, and
* tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
* processed by the expert with expert_id within the current thread's token
* shard.
*/
int32_t rank_post_pad =
tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
cumsum[expert_id];
int mask = (int)token_lora_mapping[i / topk_num] == lora_id;
atomicAdd(
&sorted_token_ids[index(max_num_tokens_padded, lora_id, rank_post_pad)],
(i - numel) * mask);
tokens_cnts[index(num_experts, threadIdx.x, expert_id)] += mask;
}
}
void moe_lora_align_block_size(
torch::Tensor topk_ids, torch::Tensor token_lora_mapping,
int64_t num_experts, int64_t block_size, int64_t max_loras,
int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
torch::Tensor lora_ids) {
const int topk_num = topk_ids.size(1);
TORCH_CHECK(block_size > 0, "block_size should be greater than 0. ");
int device_max_shared_mem;
auto dev = topk_ids.get_device();
cudaDeviceGetAttribute(&device_max_shared_mem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int32_t num_thread = max((int32_t)num_experts, 128); // WARP_SIZE,
TORCH_CHECK(num_thread <= 1024,
"num_thread must be less than 1024, "
"and fallback is not implemented yet.");
const int32_t shared_mem = (num_thread + 1) * num_experts * sizeof(int32_t) +
(num_experts + 1) * sizeof(int32_t);
if (shared_mem > device_max_shared_mem) {
TORCH_CHECK(false,
"Shared memory usage exceeds device limit, and global memory "
"fallback is not implemented yet.");
}
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] {
dim3 blockDim(num_thread);
auto kernel = moe_lora_align_sum_kernel<scalar_t, int32_t>;
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
(void*)kernel, shared_mem));
kernel<<<max_loras, blockDim, shared_mem, stream>>>(
topk_ids.data_ptr<scalar_t>(),
token_lora_mapping.data_ptr<int32_t>(), block_size, num_experts,
max_loras, topk_ids.numel(), max_num_tokens_padded,
max_num_m_blocks, sorted_token_ids.data_ptr<int32_t>(),
expert_ids.data_ptr<int32_t>(), topk_num,
num_tokens_post_pad.data_ptr<int32_t>(),
adapter_enabled.data_ptr<int32_t>(), lora_ids.data_ptr<int32_t>());
});
}

View File

@ -27,7 +27,7 @@ void moe_lora_align_block_size(
int64_t max_num_tokens_padded, int64_t max_num_m_blocks, int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled, torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
torch::Tensor lora_ids); torch::Tensor lora_ids, std::optional<torch::Tensor> maybe_expert_map);
#ifndef USE_ROCM #ifndef USE_ROCM
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
torch::Tensor b_qweight, torch::Tensor b_scales, torch::Tensor b_qweight, torch::Tensor b_scales,

View File

@ -47,7 +47,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" Tensor !experts_ids," " Tensor !experts_ids,"
" Tensor !num_tokens_post_pad," " Tensor !num_tokens_post_pad,"
" Tensor !adapter_enabled," " Tensor !adapter_enabled,"
" Tensor !lora_ids) -> () "); " Tensor !lora_ids,"
" Tensor? maybe_expert_map) -> () ");
m.impl("moe_lora_align_block_size", torch::kCUDA, &moe_lora_align_block_size); m.impl("moe_lora_align_block_size", torch::kCUDA, &moe_lora_align_block_size);
#ifndef USE_ROCM #ifndef USE_ROCM

View File

@ -32,7 +32,7 @@ def sample_data(num_experts, max_loras, num_tokens, topk_num):
@pytest.mark.parametrize("num_tokens", [100, 200, 1024, 4096]) # 81920 @pytest.mark.parametrize("num_tokens", [100, 200, 1024, 4096]) # 81920
@pytest.mark.parametrize("topk_num", [6]) @pytest.mark.parametrize("topk_num", [6])
@pytest.mark.parametrize("num_experts", [64, 128]) @pytest.mark.parametrize("num_experts", [64, 128, 256, 512])
@pytest.mark.parametrize("max_loras", [2, 32]) @pytest.mark.parametrize("max_loras", [2, 32])
@pytest.mark.parametrize("block_size", [16]) @pytest.mark.parametrize("block_size", [16])
def test_moe_lora_align_block_size( def test_moe_lora_align_block_size(

View File

@ -1961,6 +1961,7 @@ def moe_lora_align_block_size(
num_tokens_post_pad: torch.Tensor, num_tokens_post_pad: torch.Tensor,
adapter_enabled: torch.Tensor, adapter_enabled: torch.Tensor,
lora_ids: torch.Tensor, lora_ids: torch.Tensor,
expert_map: torch.Tensor | None = None,
) -> None: ) -> None:
torch.ops._moe_C.moe_lora_align_block_size( torch.ops._moe_C.moe_lora_align_block_size(
topk_ids, topk_ids,
@ -1975,6 +1976,7 @@ def moe_lora_align_block_size(
num_tokens_post_pad, num_tokens_post_pad,
adapter_enabled, adapter_enabled,
lora_ids, lora_ids,
expert_map,
) )