[Kernel][MoE] optimize moe_align_block_size (#29642)

Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jinzhen Lin 2025-12-07 17:58:47 +08:00 committed by GitHub
parent 1b0482b9d1
commit 879ddb09c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 195 additions and 63 deletions

View File

@ -24,12 +24,15 @@ def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
num_tokens_range = [1, 16, 256, 4096] num_tokens_range = [1, 16, 256, 4096]
num_experts_range = [16, 64, 224, 256, 280, 512] num_experts_range = [16, 64, 224, 256, 280, 512]
topk_range = [1, 2, 8] topk_range = [1, 2, 8]
configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range)) ep_size_range = [1, 8]
configs = list(
itertools.product(num_tokens_range, num_experts_range, topk_range, ep_size_range)
)
@triton.testing.perf_report( @triton.testing.perf_report(
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=["num_tokens", "num_experts", "topk"], x_names=["num_tokens", "num_experts", "topk", "ep_size"],
x_vals=configs, x_vals=configs,
line_arg="provider", line_arg="provider",
line_vals=["vllm"], line_vals=["vllm"],
@ -38,16 +41,26 @@ configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range
args={}, args={},
) )
) )
def benchmark(num_tokens, num_experts, topk, provider): def benchmark(num_tokens, num_experts, topk, ep_size, provider):
"""Benchmark function for Triton.""" """Benchmark function for Triton."""
block_size = 256 block_size = 256
torch.cuda.manual_seed_all(0)
topk_ids = get_topk_ids(num_tokens, num_experts, topk) topk_ids = get_topk_ids(num_tokens, num_experts, topk)
e_map = None
if ep_size != 1:
local_e = num_experts // ep_size
e_ids = torch.randperm(num_experts, device="cuda", dtype=torch.int32)[:local_e]
e_map = torch.full((num_experts,), -1, device="cuda", dtype=torch.int32)
e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
quantiles = [0.5, 0.2, 0.8] quantiles = [0.5, 0.2, 0.8]
if provider == "vllm": if provider == "vllm":
ms, min_ms, max_ms = triton.testing.do_bench( ms, min_ms, max_ms = triton.testing.do_bench(
lambda: moe_align_block_size(topk_ids, block_size, num_experts), lambda: moe_align_block_size(
topk_ids, block_size, num_experts, e_map, ignore_invalid_experts=True
),
quantiles=quantiles, quantiles=quantiles,
) )

View File

@ -83,14 +83,22 @@ template <typename scalar_t>
__global__ void moe_align_block_size_kernel( __global__ void moe_align_block_size_kernel(
const scalar_t* __restrict__ topk_ids, const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts, int32_t* __restrict__ total_tokens_post_pad,
int32_t* __restrict__ expert_map, int32_t num_experts,
int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size, int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size,
size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded) { size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded,
bool has_expert_map) {
extern __shared__ int32_t shared_counts[]; extern __shared__ int32_t shared_counts[];
// Initialize sorted_token_ids with numel // Use a separate threadblock to fill sorted_token_ids.
for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) { // This is safe since the current kernel does not use sorted_token_ids.
sorted_token_ids[it] = numel; if (blockIdx.x == 1) {
// Initialize sorted_token_ids with numel
for (size_t it = threadIdx.x; it < max_num_tokens_padded;
it += blockDim.x) {
sorted_token_ids[it] = numel;
}
return;
} }
const int warp_id = threadIdx.x / WARP_SIZE; const int warp_id = threadIdx.x / WARP_SIZE;
@ -112,6 +120,11 @@ __global__ void moe_align_block_size_kernel(
if (expert_id >= num_experts) { if (expert_id >= num_experts) {
continue; continue;
} }
if (has_expert_map) {
expert_id = expert_map[expert_id];
// filter invalid experts
if (expert_id == -1) continue;
}
int warp_idx = expert_id / experts_per_warp; int warp_idx = expert_id / experts_per_warp;
int expert_offset = expert_id % experts_per_warp; int expert_offset = expert_id % experts_per_warp;
atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], 1); atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], 1);
@ -163,7 +176,8 @@ template <typename scalar_t>
__global__ void count_and_sort_expert_tokens_kernel( __global__ void count_and_sort_expert_tokens_kernel(
const scalar_t* __restrict__ topk_ids, const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer, int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer,
size_t numel, int32_t num_experts) { int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts,
bool has_expert_map) {
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x; const size_t stride = blockDim.x * gridDim.x;
@ -172,6 +186,11 @@ __global__ void count_and_sort_expert_tokens_kernel(
if (expert_id >= num_experts) { if (expert_id >= num_experts) {
continue; continue;
} }
if (has_expert_map) {
expert_id = expert_map[expert_id];
// filter invalid experts
if (expert_id == -1) continue;
}
int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1); int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
sorted_token_ids[rank_post_pad] = i; sorted_token_ids[rank_post_pad] = i;
} }
@ -193,50 +212,69 @@ __global__ void moe_sum_kernel(
} }
} }
template <typename scalar_t> template <typename scalar_t, int32_t fill_threads>
__global__ void moe_align_block_size_small_batch_expert_kernel( __global__ void moe_align_block_size_small_batch_expert_kernel(
const scalar_t* __restrict__ topk_ids, const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts, int32_t* __restrict__ total_tokens_post_pad,
int32_t block_size, size_t numel, int32_t max_num_tokens_padded) { int32_t* __restrict__ expert_map, int32_t num_experts, int32_t block_size,
// Initialize sorted_token_ids with numel size_t numel, int32_t max_num_tokens_padded, bool has_expert_map) {
for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) { // Use an additional group of threads to fill sorted_token_ids.
sorted_token_ids[it] = numel; // Since the current kernel will use sorted_token_ids afterward,
// we fill sorted_token_ids within the same threadblock to make
// synchronization easier.
if (threadIdx.x < fill_threads) {
// Initialize sorted_token_ids with numel
for (size_t it = threadIdx.x; it < max_num_tokens_padded;
it += fill_threads) {
sorted_token_ids[it] = numel;
}
// Issue three __syncthreads() before returning, one for each
// __syncthreads() that the counting threads execute below.
__syncthreads();
__syncthreads();
__syncthreads();
return;
} }
const size_t tid = threadIdx.x; const size_t tid = threadIdx.x - fill_threads;
const size_t stride = blockDim.x; const size_t stride = blockDim.x - fill_threads;
extern __shared__ int32_t shared_mem[]; extern __shared__ int32_t shared_mem[];
int32_t* cumsum = shared_mem; int32_t* cumsum = shared_mem;
int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1); int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1);
for (int i = 0; i < num_experts; ++i) { for (int i = 0; i < num_experts; ++i) {
tokens_cnts[(threadIdx.x + 1) * num_experts + i] = 0; tokens_cnts[(tid + 1) * num_experts + i] = 0;
} }
for (size_t i = tid; i < numel; i += stride) { for (size_t i = tid; i < numel; i += stride) {
++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i]]; int32_t expert_id = topk_ids[i];
if (has_expert_map) {
expert_id = expert_map[expert_id];
// filter invalid expert
if (expert_id == -1) continue;
}
++tokens_cnts[(tid + 1) * num_experts + expert_id];
} }
__syncthreads(); __syncthreads();
if (threadIdx.x < num_experts) { if (tid < num_experts) {
tokens_cnts[threadIdx.x] = 0; tokens_cnts[tid] = 0;
for (int i = 1; i <= blockDim.x; ++i) { for (int i = 1; i <= stride; ++i) {
tokens_cnts[i * num_experts + threadIdx.x] += tokens_cnts[i * num_experts + tid] +=
tokens_cnts[(i - 1) * num_experts + threadIdx.x]; tokens_cnts[(i - 1) * num_experts + tid];
} }
} }
__syncthreads(); __syncthreads();
if (threadIdx.x == 0) { if (tid == 0) {
cumsum[0] = 0; cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) { for (int i = 1; i <= num_experts; ++i) {
cumsum[i] = cumsum[i] =
cumsum[i - 1] + cumsum[i - 1] +
CEILDIV(tokens_cnts[blockDim.x * num_experts + i - 1], block_size) * CEILDIV(tokens_cnts[stride * num_experts + i - 1], block_size) *
block_size; block_size;
} }
*total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]); *total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]);
@ -244,26 +282,30 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
__syncthreads(); __syncthreads();
if (threadIdx.x < num_experts) { if (tid < num_experts) {
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; for (int i = cumsum[tid]; i < cumsum[tid + 1]; i += block_size) {
i += block_size) { expert_ids[i / block_size] = tid;
expert_ids[i / block_size] = threadIdx.x;
} }
} }
// Fill remaining expert_ids with 0 // Fill remaining expert_ids with 0
const size_t fill_start_idx = cumsum[num_experts] / block_size + threadIdx.x; const size_t fill_start_idx = cumsum[num_experts] / block_size + tid;
const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size); const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size);
for (size_t i = fill_start_idx; i < expert_ids_size; i += blockDim.x) { for (size_t i = fill_start_idx; i < expert_ids_size; i += stride) {
expert_ids[i] = 0; expert_ids[i] = 0;
} }
for (size_t i = tid; i < numel; i += stride) { for (size_t i = tid; i < numel; i += stride) {
int32_t expert_id = topk_ids[i]; int32_t expert_id = topk_ids[i];
if (has_expert_map) {
expert_id = expert_map[expert_id];
// filter invalid expert
if (expert_id == -1) continue;
}
int32_t rank_post_pad = int32_t rank_post_pad =
tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id]; tokens_cnts[tid * num_experts + expert_id] + cumsum[expert_id];
sorted_token_ids[rank_post_pad] = i; sorted_token_ids[rank_post_pad] = i;
++tokens_cnts[threadIdx.x * num_experts + expert_id]; ++tokens_cnts[tid * num_experts + expert_id];
} }
} }
@ -275,7 +317,8 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids, int64_t block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids, torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad) { torch::Tensor num_tokens_post_pad,
std::optional<torch::Tensor> maybe_expert_map) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int64_t padded_num_experts = int64_t padded_num_experts =
@ -287,14 +330,19 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
// BlockScan uses 1024 threads and assigns one thread per expert. // BlockScan uses 1024 threads and assigns one thread per expert.
TORCH_CHECK(padded_num_experts < 1024, TORCH_CHECK(padded_num_experts < 1024,
"padded_num_experts must be less than 1024"); "padded_num_experts must be less than 1024");
auto options_int =
torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
bool has_expert_map = maybe_expert_map.has_value();
torch::Tensor expert_map;
if (has_expert_map) {
expert_map = maybe_expert_map.value();
} else {
expert_map = torch::empty({0}, options_int);
}
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `cumsum` tensors // calc needed amount of shared mem for `cumsum` tensors
auto options_int =
torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
torch::Tensor cumsum_buffer =
torch::empty({num_experts + 1}, options_int);
bool small_batch_expert_mode = bool small_batch_expert_mode =
(topk_ids.numel() < 1024) && (num_experts <= 64); (topk_ids.numel() < 1024) && (num_experts <= 64);
@ -304,30 +352,41 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
((threads + 1) * num_experts + (num_experts + 1)) * ((threads + 1) * num_experts + (num_experts + 1)) *
sizeof(int32_t); sizeof(int32_t);
// threadIdx.x >= fill_threads: counting experts and aligning
// threadIdx.x < fill_threads: filling sorted_token_ids
constexpr int32_t fill_threads = 256;
auto small_batch_expert_kernel = auto small_batch_expert_kernel =
vllm::moe::moe_align_block_size_small_batch_expert_kernel< vllm::moe::moe_align_block_size_small_batch_expert_kernel<
scalar_t>; scalar_t, fill_threads>;
small_batch_expert_kernel<<<1, threads, shared_mem_size, stream>>>( small_batch_expert_kernel<<<1, fill_threads + threads,
shared_mem_size, stream>>>(
topk_ids.data_ptr<scalar_t>(), topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(), sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(), experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size, num_tokens_post_pad.data_ptr<int32_t>(),
topk_ids.numel(), sorted_token_ids.size(0)); expert_map.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.numel(), sorted_token_ids.size(0), has_expert_map);
} else { } else {
torch::Tensor cumsum_buffer =
torch::empty({num_experts + 1}, options_int);
auto align_kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>; auto align_kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp); size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp);
size_t shared_mem_size = size_t shared_mem_size =
num_warps * experts_per_warp * sizeof(int32_t); num_warps * experts_per_warp * sizeof(int32_t);
align_kernel<<<1, threads, shared_mem_size, stream>>>( // launch two threadblocks
// blockIdx.x == 0: counting experts and aligning
// blockIdx.x == 1: filling sorted_token_ids
align_kernel<<<2, threads, shared_mem_size, stream>>>(
topk_ids.data_ptr<scalar_t>(), topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(), sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(), experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, num_tokens_post_pad.data_ptr<int32_t>(),
padded_num_experts, experts_per_warp, block_size, expert_map.data_ptr<int32_t>(), num_experts, padded_num_experts,
topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>(), experts_per_warp, block_size, topk_ids.numel(),
sorted_token_ids.size(0)); cumsum_buffer.data_ptr<int32_t>(), sorted_token_ids.size(0),
has_expert_map);
const int block_threads = std::min(256, (int)threads); const int block_threads = std::min(256, (int)threads);
const int num_blocks = const int num_blocks =
@ -340,7 +399,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
sort_kernel<<<actual_blocks, block_threads, 0, stream>>>( sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
topk_ids.data_ptr<scalar_t>(), topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(), sorted_token_ids.data_ptr<int32_t>(),
cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel(), num_experts); cumsum_buffer.data_ptr<int32_t>(), expert_map.data_ptr<int32_t>(),
topk_ids.numel(), num_experts, has_expert_map);
} }
}); });
} }

View File

@ -11,7 +11,8 @@ void moe_sum(torch::Tensor& input, torch::Tensor& output);
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids, int64_t block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids, torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad); torch::Tensor num_tokens_post_pad,
std::optional<torch::Tensor> maybe_expert_map);
void batched_moe_align_block_size(int64_t max_tokens_per_batch, void batched_moe_align_block_size(int64_t max_tokens_per_batch,
int64_t block_size, int64_t block_size,

View File

@ -19,7 +19,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"moe_align_block_size(Tensor topk_ids, int num_experts," "moe_align_block_size(Tensor topk_ids, int num_experts,"
" int block_size, Tensor! sorted_token_ids," " int block_size, Tensor! sorted_token_ids,"
" Tensor! experts_ids," " Tensor! experts_ids,"
" Tensor! num_tokens_post_pad) -> ()"); " Tensor! num_tokens_post_pad,"
" Tensor? maybe_expert_map) -> ()");
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
// Aligning the number of tokens to be processed by each expert such // Aligning the number of tokens to be processed by each expert such

View File

@ -955,9 +955,22 @@ def test_fused_marlin_moe_with_bias(m):
torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0) torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
def test_moe_align_block_size_opcheck(): @pytest.mark.parametrize("ep_size", [1, 2])
def test_moe_align_block_size_opcheck(ep_size):
num_experts = 4 num_experts = 4
block_size = 4 block_size = 4
expert_map = None
if ep_size != 1:
local_num_experts = num_experts // ep_size
expert_ids = torch.randint(
0, num_experts, (local_num_experts,), device="cuda", dtype=torch.int32
)
expert_map = torch.full((num_experts,), -1, device="cuda", dtype=torch.int32)
expert_map[expert_ids] = torch.arange(
local_num_experts, device="cuda", dtype=torch.int32
)
topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda") topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda")
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
@ -980,6 +993,7 @@ def test_moe_align_block_size_opcheck():
sorted_ids, sorted_ids,
expert_ids, expert_ids,
num_tokens_post_pad, num_tokens_post_pad,
expert_map,
), ),
) )

View File

@ -106,6 +106,8 @@ def torch_moe_align_block_size(
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
if pad_sorted_ids: if pad_sorted_ids:
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
if topk_ids.numel() < num_experts:
max_num_tokens_padded = topk_ids.numel() * block_size
flattened_token_indices = torch.arange( flattened_token_indices = torch.arange(
topk_ids.numel(), device=topk_ids.device, dtype=torch.int32 topk_ids.numel(), device=topk_ids.device, dtype=torch.int32
@ -126,6 +128,8 @@ def torch_moe_align_block_size(
) )
for expert_id in range(num_experts): for expert_id in range(num_experts):
original_count = expert_token_counts[expert_id] original_count = expert_token_counts[expert_id]
if expert_map is not None and expert_map[expert_id] == -1:
continue
if original_count > 0: if original_count > 0:
expert_padded_counts[expert_id] = ( expert_padded_counts[expert_id] = (
(original_count + block_size - 1) // block_size (original_count + block_size - 1) // block_size
@ -143,6 +147,9 @@ def torch_moe_align_block_size(
current_pos = 0 current_pos = 0
current_block = 0 current_block = 0
for expert_id in range(num_experts): for expert_id in range(num_experts):
if expert_map is not None and expert_map[expert_id] == -1:
continue
expert_mask = sorted_expert_ids == expert_id expert_mask = sorted_expert_ids == expert_id
expert_tokens = sorted_token_indices[expert_mask] expert_tokens = sorted_token_indices[expert_mask]
num_expert_tokens = expert_tokens.shape[0] num_expert_tokens = expert_tokens.shape[0]
@ -153,7 +160,13 @@ def torch_moe_align_block_size(
) )
expert_blocks_needed = expert_padded_counts[expert_id] // block_size expert_blocks_needed = expert_padded_counts[expert_id] // block_size
expert_ids[current_block : current_block + expert_blocks_needed] = expert_id
expert_id_new = expert_id
if expert_map is not None:
expert_id_new = expert_map[expert_id]
expert_ids[current_block : current_block + expert_blocks_needed] = (
expert_id_new
)
current_pos += expert_padded_counts[expert_id] current_pos += expert_padded_counts[expert_id]
current_block += expert_blocks_needed current_block += expert_blocks_needed
@ -163,8 +176,6 @@ def torch_moe_align_block_size(
[total_padded_tokens], dtype=torch.int32, device=topk_ids.device [total_padded_tokens], dtype=torch.int32, device=topk_ids.device
) )
if expert_map is not None:
expert_ids = expert_map[expert_ids]
return sorted_token_ids, expert_ids, num_tokens_post_pad return sorted_token_ids, expert_ids, num_tokens_post_pad
@ -229,9 +240,9 @@ def test_moe_align_block_size(
) )
@pytest.mark.parametrize("m", [16, 32]) @pytest.mark.parametrize("m", [16, 32, 2048])
@pytest.mark.parametrize("topk", [2, 4]) @pytest.mark.parametrize("topk", [2, 4])
@pytest.mark.parametrize("num_experts", [8]) @pytest.mark.parametrize("num_experts", [8, 64])
@pytest.mark.parametrize("block_size", [64]) @pytest.mark.parametrize("block_size", [64])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_align_block_size_with_expert_map( def test_moe_align_block_size_with_expert_map(
@ -253,6 +264,7 @@ def test_moe_align_block_size_with_expert_map(
block_size=block_size, block_size=block_size,
num_experts=num_experts, num_experts=num_experts,
expert_map=expert_map, expert_map=expert_map,
ignore_invalid_experts=True,
) )
golden_sorted_ids, golden_expert_ids, golden_num_tokens = ( golden_sorted_ids, golden_expert_ids, golden_num_tokens = (
torch_moe_align_block_size( torch_moe_align_block_size(

View File

@ -1877,6 +1877,7 @@ def moe_align_block_size(
sorted_token_ids: torch.Tensor, sorted_token_ids: torch.Tensor,
experts_ids: torch.Tensor, experts_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor, num_tokens_post_pad: torch.Tensor,
expert_map: torch.Tensor | None = None,
) -> None: ) -> None:
torch.ops._moe_C.moe_align_block_size( torch.ops._moe_C.moe_align_block_size(
topk_ids, topk_ids,
@ -1885,6 +1886,7 @@ def moe_align_block_size(
sorted_token_ids, sorted_token_ids,
experts_ids, experts_ids,
num_tokens_post_pad, num_tokens_post_pad,
expert_map,
) )

View File

@ -316,7 +316,11 @@ def fused_marlin_moe(
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = E global_num_experts = E
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, block_size_m, global_num_experts, expert_map topk_ids,
block_size_m,
global_num_experts,
expert_map,
ignore_invalid_experts=True,
) )
assert activation is not None assert activation is not None

View File

@ -1887,7 +1887,11 @@ def fused_experts_impl(
) )
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
curr_topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map curr_topk_ids,
config["BLOCK_SIZE_M"],
global_num_experts,
expert_map,
ignore_invalid_experts=True,
) )
invoke_fused_moe_kernel( invoke_fused_moe_kernel(
@ -1946,6 +1950,9 @@ def fused_experts_impl(
block_shape=block_shape, block_shape=block_shape,
) )
if expert_map is not None:
intermediate_cache3.zero_()
invoke_fused_moe_kernel( invoke_fused_moe_kernel(
qintermediate_cache2, qintermediate_cache2,
w2, w2,

View File

@ -14,6 +14,7 @@ def moe_align_block_size(
num_experts: int, num_experts: int,
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
pad_sorted_ids: bool = False, pad_sorted_ids: bool = False,
ignore_invalid_experts: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
Aligns the token distribution across experts to be compatible with block Aligns the token distribution across experts to be compatible with block
@ -35,7 +36,13 @@ def moe_align_block_size(
expert parallel shard. If the expert is not in the current expert expert parallel shard. If the expert is not in the current expert
parallel shard, the mapping is set to -1. parallel shard, the mapping is set to -1.
- pad_sorted_ids: A flag indicating whether the sorted_token_ids length - pad_sorted_ids: A flag indicating whether the sorted_token_ids length
should be padded to a multiple of block_size, should be padded to a multiple of block_size,
- ignore_invalid_experts: A flag indicating whether to ignore invalid
experts, i.e. experts that expert_map maps to -1. It only has an effect
when expert_map is given. When False, every expert id in topk_ids
participates in counting and ranking, and the entries of expert_ids that
belong to invalid experts are marked as -1. When True, invalid expert ids
in topk_ids are skipped and do not participate in counting or ranking,
and there will be no -1 in expert_ids.
Returns: Returns:
- sorted_token_ids: A tensor containing the sorted token indices according - sorted_token_ids: A tensor containing the sorted token indices according
@ -67,6 +74,10 @@ def moe_align_block_size(
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
if pad_sorted_ids: if pad_sorted_ids:
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
if topk_ids.numel() < num_experts:
max_num_tokens_padded = min(
topk_ids.numel() * block_size, max_num_tokens_padded
)
sorted_ids = torch.empty( sorted_ids = torch.empty(
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
) )
@ -77,9 +88,16 @@ def moe_align_block_size(
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
ops.moe_align_block_size( ops.moe_align_block_size(
topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
expert_map if ignore_invalid_experts else None,
) )
if expert_map is not None:
if expert_map is not None and not ignore_invalid_experts:
expert_ids = expert_map[expert_ids] expert_ids = expert_map[expert_ids]
return sorted_ids, expert_ids, num_tokens_post_pad return sorted_ids, expert_ids, num_tokens_post_pad