[Kernel][MoE] optimize moe_align_block_size (#29642)
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
parent 1b0482b9d1
commit 879ddb09c3
@@ -24,12 +24,15 @@ def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
 num_tokens_range = [1, 16, 256, 4096]
 num_experts_range = [16, 64, 224, 256, 280, 512]
 topk_range = [1, 2, 8]
-configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
+ep_size_range = [1, 8]
+configs = list(
+    itertools.product(num_tokens_range, num_experts_range, topk_range, ep_size_range)
+)
 
 
 @triton.testing.perf_report(
     triton.testing.Benchmark(
-        x_names=["num_tokens", "num_experts", "topk"],
+        x_names=["num_tokens", "num_experts", "topk", "ep_size"],
         x_vals=configs,
         line_arg="provider",
         line_vals=["vllm"],
@@ -38,16 +41,26 @@ configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range
         args={},
     )
 )
-def benchmark(num_tokens, num_experts, topk, provider):
+def benchmark(num_tokens, num_experts, topk, ep_size, provider):
     """Benchmark function for Triton."""
     block_size = 256
     torch.cuda.manual_seed_all(0)
     topk_ids = get_topk_ids(num_tokens, num_experts, topk)
 
+    e_map = None
+    if ep_size != 1:
+        local_e = num_experts // ep_size
+        e_ids = torch.randperm(num_experts, device="cuda", dtype=torch.int32)[:local_e]
+        e_map = torch.full((num_experts,), -1, device="cuda", dtype=torch.int32)
+        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
+
     quantiles = [0.5, 0.2, 0.8]
 
     if provider == "vllm":
         ms, min_ms, max_ms = triton.testing.do_bench(
-            lambda: moe_align_block_size(topk_ids, block_size, num_experts),
+            lambda: moe_align_block_size(
+                topk_ids, block_size, num_experts, e_map, ignore_invalid_experts=True
+            ),
            quantiles=quantiles,
        )
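Illustrative note (not part of the patch): the ep_size sweep above exercises expert parallelism by pretending that only num_experts // ep_size experts are local to the rank. The e_map it builds sends each local global expert id to a compact local index and every other expert to -1. A CPU-only sketch of the same construction:

import torch

# Mirrors the benchmark's expert-parallel map; CPU tensors for illustration only.
num_experts, ep_size = 8, 2
local_e = num_experts // ep_size                       # experts owned by this rank
e_ids = torch.randperm(num_experts)[:local_e]          # which global experts are local
e_map = torch.full((num_experts,), -1, dtype=torch.int32)
e_map[e_ids] = torch.arange(local_e, dtype=torch.int32)
print(e_ids.tolist())  # e.g. [5, 0, 3, 6]
print(e_map.tolist())  # e.g. [1, -1, -1, 2, -1, 0, 3, -1]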
@@ -83,14 +83,22 @@ template <typename scalar_t>
 __global__ void moe_align_block_size_kernel(
     const scalar_t* __restrict__ topk_ids,
     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
-    int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
+    int32_t* __restrict__ total_tokens_post_pad,
+    int32_t* __restrict__ expert_map, int32_t num_experts,
     int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size,
-    size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded) {
+    size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded,
+    bool has_expert_map) {
   extern __shared__ int32_t shared_counts[];
 
-  // Initialize sorted_token_ids with numel
-  for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
-    sorted_token_ids[it] = numel;
+  // Use a separate threadblock to fill sorted_token_ids.
+  // This is safe since the current kernel does not use sorted_token_ids.
+  if (blockIdx.x == 1) {
+    // Initialize sorted_token_ids with numel
+    for (size_t it = threadIdx.x; it < max_num_tokens_padded;
+         it += blockDim.x) {
+      sorted_token_ids[it] = numel;
+    }
+    return;
   }
 
   const int warp_id = threadIdx.x / WARP_SIZE;
@@ -112,6 +120,11 @@ __global__ void moe_align_block_size_kernel(
     if (expert_id >= num_experts) {
       continue;
     }
+    if (has_expert_map) {
+      expert_id = expert_map[expert_id];
+      // filter invalid experts
+      if (expert_id == -1) continue;
+    }
     int warp_idx = expert_id / experts_per_warp;
     int expert_offset = expert_id % experts_per_warp;
     atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], 1);
@@ -163,7 +176,8 @@ template <typename scalar_t>
 __global__ void count_and_sort_expert_tokens_kernel(
     const scalar_t* __restrict__ topk_ids,
     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer,
-    size_t numel, int32_t num_experts) {
+    int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts,
+    bool has_expert_map) {
   const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
   const size_t stride = blockDim.x * gridDim.x;
 
@@ -172,6 +186,11 @@ __global__ void count_and_sort_expert_tokens_kernel(
     if (expert_id >= num_experts) {
       continue;
     }
+    if (has_expert_map) {
+      expert_id = expert_map[expert_id];
+      // filter invalid experts
+      if (expert_id == -1) continue;
+    }
     int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
     sorted_token_ids[rank_post_pad] = i;
   }
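Illustrative note (not part of the patch): with has_expert_map set, both counting kernels above translate each token's global expert id through expert_map and simply skip tokens that map to -1, so only local experts are counted and ranked. A rough torch equivalent of that filtering step, with made-up inputs:

import torch

topk_ids = torch.tensor([0, 3, 3, 7, 1, 0])               # global expert per token slot
expert_map = torch.tensor([0, -1, -1, 1, -1, -1, -1, 2])   # global -> local, -1 = not local
num_local_experts = 3

mapped = expert_map[topk_ids]        # [0, 1, 1, 2, -1, 0]
valid = mapped != -1                 # drop tokens routed to remote experts
counts = torch.bincount(mapped[valid], minlength=num_local_experts)
print(counts.tolist())               # [2, 2, 1] -- what the kernels accumulate per local expert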
@@ -193,50 +212,69 @@ __global__ void moe_sum_kernel(
   }
 }
 
-template <typename scalar_t>
+template <typename scalar_t, int32_t fill_threads>
 __global__ void moe_align_block_size_small_batch_expert_kernel(
     const scalar_t* __restrict__ topk_ids,
     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
-    int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
-    int32_t block_size, size_t numel, int32_t max_num_tokens_padded) {
-  // Initialize sorted_token_ids with numel
-  for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
-    sorted_token_ids[it] = numel;
+    int32_t* __restrict__ total_tokens_post_pad,
+    int32_t* __restrict__ expert_map, int32_t num_experts, int32_t block_size,
+    size_t numel, int32_t max_num_tokens_padded, bool has_expert_map) {
+  // Use an additional group of threads to fill sorted_token_ids.
+  // Since the current kernel will use sorted_token_ids afterward,
+  // we fill sorted_token_ids within the same threadblock to make
+  // synchronization easier.
+  if (threadIdx.x < fill_threads) {
+    // Initialize sorted_token_ids with numel
+    for (size_t it = threadIdx.x; it < max_num_tokens_padded;
+         it += fill_threads) {
+      sorted_token_ids[it] = numel;
+    }
+    // Three __syncthreads() corresponding to the other threads
+    __syncthreads();
+    __syncthreads();
+    __syncthreads();
+    return;
   }
 
-  const size_t tid = threadIdx.x;
-  const size_t stride = blockDim.x;
+  const size_t tid = threadIdx.x - fill_threads;
+  const size_t stride = blockDim.x - fill_threads;
 
   extern __shared__ int32_t shared_mem[];
   int32_t* cumsum = shared_mem;
   int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1);
 
   for (int i = 0; i < num_experts; ++i) {
-    tokens_cnts[(threadIdx.x + 1) * num_experts + i] = 0;
+    tokens_cnts[(tid + 1) * num_experts + i] = 0;
   }
 
   for (size_t i = tid; i < numel; i += stride) {
-    ++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i]];
+    int32_t expert_id = topk_ids[i];
+    if (has_expert_map) {
+      expert_id = expert_map[expert_id];
+      // filter invalid expert
+      if (expert_id == -1) continue;
+    }
+    ++tokens_cnts[(tid + 1) * num_experts + expert_id];
   }
 
   __syncthreads();
 
-  if (threadIdx.x < num_experts) {
-    tokens_cnts[threadIdx.x] = 0;
-    for (int i = 1; i <= blockDim.x; ++i) {
-      tokens_cnts[i * num_experts + threadIdx.x] +=
-          tokens_cnts[(i - 1) * num_experts + threadIdx.x];
+  if (tid < num_experts) {
+    tokens_cnts[tid] = 0;
+    for (int i = 1; i <= stride; ++i) {
+      tokens_cnts[i * num_experts + tid] +=
+          tokens_cnts[(i - 1) * num_experts + tid];
     }
   }
 
   __syncthreads();
 
-  if (threadIdx.x == 0) {
+  if (tid == 0) {
     cumsum[0] = 0;
     for (int i = 1; i <= num_experts; ++i) {
       cumsum[i] =
           cumsum[i - 1] +
-          CEILDIV(tokens_cnts[blockDim.x * num_experts + i - 1], block_size) *
+          CEILDIV(tokens_cnts[stride * num_experts + i - 1], block_size) *
               block_size;
     }
     *total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]);
@@ -244,26 +282,30 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
 
   __syncthreads();
 
-  if (threadIdx.x < num_experts) {
-    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
-         i += block_size) {
-      expert_ids[i / block_size] = threadIdx.x;
+  if (tid < num_experts) {
+    for (int i = cumsum[tid]; i < cumsum[tid + 1]; i += block_size) {
+      expert_ids[i / block_size] = tid;
     }
   }
 
   // Fill remaining expert_ids with 0
-  const size_t fill_start_idx = cumsum[num_experts] / block_size + threadIdx.x;
+  const size_t fill_start_idx = cumsum[num_experts] / block_size + tid;
   const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size);
-  for (size_t i = fill_start_idx; i < expert_ids_size; i += blockDim.x) {
+  for (size_t i = fill_start_idx; i < expert_ids_size; i += stride) {
     expert_ids[i] = 0;
   }
 
   for (size_t i = tid; i < numel; i += stride) {
     int32_t expert_id = topk_ids[i];
+    if (has_expert_map) {
+      expert_id = expert_map[expert_id];
+      // filter invalid expert
+      if (expert_id == -1) continue;
+    }
     int32_t rank_post_pad =
-        tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id];
+        tokens_cnts[tid * num_experts + expert_id] + cumsum[expert_id];
     sorted_token_ids[rank_post_pad] = i;
-    ++tokens_cnts[threadIdx.x * num_experts + expert_id];
+    ++tokens_cnts[tid * num_experts + expert_id];
   }
 }
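Illustrative note (not part of the patch): in the small-batch kernel the first fill_threads threads only pad sorted_token_ids (executing three __syncthreads() to stay in lockstep with the other threads), while the remaining threads count tokens per expert and build the block-aligned prefix sum used for placement: each expert's segment is rounded up to a multiple of block_size. The cumsum loop, in plain Python with toy counts:

# Block-aligned prefix sum, mirroring the CEILDIV-based cumsum in the kernel.
def ceildiv(a: int, b: int) -> int:
    return -(-a // b)

block_size = 4
counts = [5, 0, 3, 9]          # tokens routed to each local expert

cumsum = [0]
for c in counts:
    cumsum.append(cumsum[-1] + ceildiv(c, block_size) * block_size)

print(cumsum)       # [0, 8, 8, 12, 24]
print(cumsum[-1])   # total_tokens_post_pad: every expert segment is a multiple of block_size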
@@ -275,7 +317,8 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
 void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                           int64_t block_size, torch::Tensor sorted_token_ids,
                           torch::Tensor experts_ids,
-                          torch::Tensor num_tokens_post_pad) {
+                          torch::Tensor num_tokens_post_pad,
+                          std::optional<torch::Tensor> maybe_expert_map) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   int64_t padded_num_experts =
@@ -287,14 +330,19 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
   // BlockScan uses 1024 threads and assigns one thread per expert.
   TORCH_CHECK(padded_num_experts < 1024,
               "padded_num_experts must be less than 1024");
+  auto options_int =
+      torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
+  bool has_expert_map = maybe_expert_map.has_value();
+  torch::Tensor expert_map;
+  if (has_expert_map) {
+    expert_map = maybe_expert_map.value();
+  } else {
+    expert_map = torch::empty({0}, options_int);
+  }
 
   VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
       topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
         // calc needed amount of shared mem for `cumsum` tensors
-        auto options_int =
-            torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
-        torch::Tensor cumsum_buffer =
-            torch::empty({num_experts + 1}, options_int);
         bool small_batch_expert_mode =
             (topk_ids.numel() < 1024) && (num_experts <= 64);
 
@@ -304,30 +352,41 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
               ((threads + 1) * num_experts + (num_experts + 1)) *
               sizeof(int32_t);
 
+          // threadIdx.x >= fill_threads: counting experts and aligning
+          // threadIdx.x < fill_threads: filling sorted_token_ids
+          constexpr int32_t fill_threads = 256;
           auto small_batch_expert_kernel =
               vllm::moe::moe_align_block_size_small_batch_expert_kernel<
-                  scalar_t>;
-          small_batch_expert_kernel<<<1, threads, shared_mem_size, stream>>>(
+                  scalar_t, fill_threads>;
+          small_batch_expert_kernel<<<1, fill_threads + threads,
+                                      shared_mem_size, stream>>>(
               topk_ids.data_ptr<scalar_t>(),
              sorted_token_ids.data_ptr<int32_t>(),
              experts_ids.data_ptr<int32_t>(),
-              num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
-              topk_ids.numel(), sorted_token_ids.size(0));
+              num_tokens_post_pad.data_ptr<int32_t>(),
+              expert_map.data_ptr<int32_t>(), num_experts, block_size,
+              topk_ids.numel(), sorted_token_ids.size(0), has_expert_map);
         } else {
+          torch::Tensor cumsum_buffer =
+              torch::empty({num_experts + 1}, options_int);
          auto align_kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
 
          size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp);
          size_t shared_mem_size =
              num_warps * experts_per_warp * sizeof(int32_t);
 
-          align_kernel<<<1, threads, shared_mem_size, stream>>>(
+          // launch two threadblocks
+          // blockIdx.x == 0: counting experts and aligning
+          // blockIdx.x == 1: filling sorted_token_ids
+          align_kernel<<<2, threads, shared_mem_size, stream>>>(
              topk_ids.data_ptr<scalar_t>(),
              sorted_token_ids.data_ptr<int32_t>(),
              experts_ids.data_ptr<int32_t>(),
-              num_tokens_post_pad.data_ptr<int32_t>(), num_experts,
-              padded_num_experts, experts_per_warp, block_size,
-              topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>(),
-              sorted_token_ids.size(0));
+              num_tokens_post_pad.data_ptr<int32_t>(),
+              expert_map.data_ptr<int32_t>(), num_experts, padded_num_experts,
+              experts_per_warp, block_size, topk_ids.numel(),
+              cumsum_buffer.data_ptr<int32_t>(), sorted_token_ids.size(0),
+              has_expert_map);
 
          const int block_threads = std::min(256, (int)threads);
          const int num_blocks =
@@ -340,7 +399,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
          sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
              topk_ids.data_ptr<scalar_t>(),
              sorted_token_ids.data_ptr<int32_t>(),
-              cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel(), num_experts);
+              cumsum_buffer.data_ptr<int32_t>(), expert_map.data_ptr<int32_t>(),
+              topk_ids.numel(), num_experts, has_expert_map);
        }
      });
 }
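Illustrative note (not part of the patch): in the general path the grid now has two threadblocks. blockIdx.x == 1 only pre-fills sorted_token_ids with numel, i.e. one past the last valid flattened token index, so padding slots are recognisable downstream; blockIdx.x == 0 does the counting, and the separate count_and_sort kernel later scatters real token indices into the leading slots. The sentinel convention in a toy torch snippet:

import torch

numel = 6                      # topk_ids.numel()
max_num_tokens_padded = 12     # capacity of sorted_token_ids, i.e. sorted_token_ids.size(0)
sorted_token_ids = torch.full((max_num_tokens_padded,), numel, dtype=torch.int32)
# The sorting kernel overwrites the leading slots with real token indices;
# anything still equal to `numel` is padding.
print(sorted_token_ids.tolist())   # [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6]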
@@ -11,7 +11,8 @@ void moe_sum(torch::Tensor& input, torch::Tensor& output);
 void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                           int64_t block_size, torch::Tensor sorted_token_ids,
                           torch::Tensor experts_ids,
-                          torch::Tensor num_tokens_post_pad);
+                          torch::Tensor num_tokens_post_pad,
+                          std::optional<torch::Tensor> maybe_expert_map);
 
 void batched_moe_align_block_size(int64_t max_tokens_per_batch,
                                   int64_t block_size,
@@ -19,7 +19,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       "moe_align_block_size(Tensor topk_ids, int num_experts,"
       " int block_size, Tensor! sorted_token_ids,"
       " Tensor! experts_ids,"
-      " Tensor! num_tokens_post_pad) -> ()");
+      " Tensor! num_tokens_post_pad,"
+      " Tensor? maybe_expert_map) -> ()");
   m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
 
   // Aligning the number of tokens to be processed by each expert such
@@ -955,9 +955,22 @@ def test_fused_marlin_moe_with_bias(m):
     torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
 
 
-def test_moe_align_block_size_opcheck():
+@pytest.mark.parametrize("ep_size", [1, 2])
+def test_moe_align_block_size_opcheck(ep_size):
     num_experts = 4
     block_size = 4
+
+    expert_map = None
+    if ep_size != 1:
+        local_num_experts = num_experts // ep_size
+        expert_ids = torch.randint(
+            0, num_experts, (local_num_experts,), device="cuda", dtype=torch.int32
+        )
+        expert_map = torch.full((num_experts,), -1, device="cuda", dtype=torch.int32)
+        expert_map[expert_ids] = torch.arange(
+            local_num_experts, device="cuda", dtype=torch.int32
+        )
+
     topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda")
 
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
@@ -980,6 +993,7 @@ def test_moe_align_block_size_opcheck():
             sorted_ids,
             expert_ids,
             num_tokens_post_pad,
+            expert_map,
         ),
     )
@@ -106,6 +106,8 @@ def torch_moe_align_block_size(
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
     if pad_sorted_ids:
         max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
+    if topk_ids.numel() < num_experts:
+        max_num_tokens_padded = topk_ids.numel() * block_size
 
     flattened_token_indices = torch.arange(
         topk_ids.numel(), device=topk_ids.device, dtype=torch.int32
@@ -126,6 +128,8 @@ def torch_moe_align_block_size(
     )
     for expert_id in range(num_experts):
         original_count = expert_token_counts[expert_id]
+        if expert_map is not None and expert_map[expert_id] == -1:
+            continue
         if original_count > 0:
             expert_padded_counts[expert_id] = (
                 (original_count + block_size - 1) // block_size
@@ -143,6 +147,9 @@ def torch_moe_align_block_size(
     current_pos = 0
     current_block = 0
     for expert_id in range(num_experts):
+        if expert_map is not None and expert_map[expert_id] == -1:
+            continue
+
         expert_mask = sorted_expert_ids == expert_id
         expert_tokens = sorted_token_indices[expert_mask]
         num_expert_tokens = expert_tokens.shape[0]
@@ -153,7 +160,13 @@ def torch_moe_align_block_size(
         )
 
         expert_blocks_needed = expert_padded_counts[expert_id] // block_size
-        expert_ids[current_block : current_block + expert_blocks_needed] = expert_id
+
+        expert_id_new = expert_id
+        if expert_map is not None:
+            expert_id_new = expert_map[expert_id]
+        expert_ids[current_block : current_block + expert_blocks_needed] = (
+            expert_id_new
+        )
 
         current_pos += expert_padded_counts[expert_id]
         current_block += expert_blocks_needed
@@ -163,8 +176,6 @@ def torch_moe_align_block_size(
         [total_padded_tokens], dtype=torch.int32, device=topk_ids.device
     )
 
-    if expert_map is not None:
-        expert_ids = expert_map[expert_ids]
     return sorted_token_ids, expert_ids, num_tokens_post_pad
 
 
@@ -229,9 +240,9 @@ def test_moe_align_block_size(
     )
 
 
-@pytest.mark.parametrize("m", [16, 32])
+@pytest.mark.parametrize("m", [16, 32, 2048])
 @pytest.mark.parametrize("topk", [2, 4])
-@pytest.mark.parametrize("num_experts", [8])
+@pytest.mark.parametrize("num_experts", [8, 64])
 @pytest.mark.parametrize("block_size", [64])
 @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
 def test_moe_align_block_size_with_expert_map(
@@ -253,6 +264,7 @@ def test_moe_align_block_size_with_expert_map(
         block_size=block_size,
         num_experts=num_experts,
         expert_map=expert_map,
+        ignore_invalid_experts=True,
     )
     golden_sorted_ids, golden_expert_ids, golden_num_tokens = (
         torch_moe_align_block_size(
@@ -1877,6 +1877,7 @@ def moe_align_block_size(
     sorted_token_ids: torch.Tensor,
     experts_ids: torch.Tensor,
     num_tokens_post_pad: torch.Tensor,
+    expert_map: torch.Tensor | None = None,
 ) -> None:
     torch.ops._moe_C.moe_align_block_size(
         topk_ids,
@@ -1885,6 +1886,7 @@ def moe_align_block_size(
         sorted_token_ids,
         experts_ids,
         num_tokens_post_pad,
+        expert_map,
     )
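Illustrative note (not part of the patch): at the Python level the new argument is just an optional trailing tensor, so existing callers keep working by passing None or omitting it. A hedged usage sketch of the low-level op through the wrapper above; it assumes a CUDA build of vLLM, and the buffer sizes follow the max_num_tokens_padded and CEILDIV conventions visible in the kernels and tests rather than any exact upstream helper:

import torch
from vllm import _custom_ops as ops

num_experts, block_size = 4, 4
topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda")

max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda")
expert_ids = torch.empty(
    ((max_num_tokens_padded + block_size - 1) // block_size,),
    dtype=torch.int32,
    device="cuda",
)
num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")

# expert_map=None keeps the old behaviour; pass a global->local map to drop
# tokens owned by other expert-parallel ranks inside the kernel.
ops.moe_align_block_size(
    topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad, None
)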
@@ -316,7 +316,11 @@ def fused_marlin_moe(
     if global_num_experts == -1:
         global_num_experts = E
     sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, block_size_m, global_num_experts, expert_map
+        topk_ids,
+        block_size_m,
+        global_num_experts,
+        expert_map,
+        ignore_invalid_experts=True,
     )
 
     assert activation is not None
@@ -1887,7 +1887,11 @@ def fused_experts_impl(
         )
 
         sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-            curr_topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map
+            curr_topk_ids,
+            config["BLOCK_SIZE_M"],
+            global_num_experts,
+            expert_map,
+            ignore_invalid_experts=True,
         )
 
         invoke_fused_moe_kernel(
@@ -1946,6 +1950,9 @@ def fused_experts_impl(
             block_shape=block_shape,
         )
 
+        if expert_map is not None:
+            intermediate_cache3.zero_()
+
         invoke_fused_moe_kernel(
             qintermediate_cache2,
             w2,
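Note (inference, not stated in the diff): with ignore_invalid_experts=True the MoE kernels no longer touch rows belonging to tokens whose experts live on other expert-parallel ranks, so intermediate_cache3 is presumably zeroed first to keep the following top-k reduction from summing stale workspace values. The effect, in a toy torch sketch:

import torch

tokens, topk, hidden = 4, 2, 8
written = torch.tensor([True, True, False, True])   # rows a local expert actually produced

out = torch.empty(tokens, topk, hidden)              # reused workspace, arbitrary contents
out[written] = 1.0
bad = out.sum(dim=1)                                 # row 2 mixes in garbage

out = torch.zeros(tokens, topk, hidden)              # zero first, as the code now does
out[written] = 1.0
good = out.sum(dim=1)                                # unwritten tokens contribute exactly zero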
@@ -14,6 +14,7 @@ def moe_align_block_size(
     num_experts: int,
     expert_map: torch.Tensor | None = None,
     pad_sorted_ids: bool = False,
+    ignore_invalid_experts: bool = False,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Aligns the token distribution across experts to be compatible with block
@@ -35,7 +36,13 @@ def moe_align_block_size(
         expert parallel shard. If the expert is not in the current expert
         parallel shard, the mapping is set to -1.
     - pad_sorted_ids: A flag indicating whether the sorted_token_ids length
-      should be padded to a multiple of block_size,
+        should be padded to a multiple of block_size,
+    - ignore_invalid_experts: A flag indicating whether to ignore invalid
+        experts. When False, all expert_ids in topk_ids will participate in
+        counting and ranking, but invalid experts in expert_ids will be marked
+        as -1. When True, all invalid expert_ids in topk_ids will be ignored
+        and will not participate in counting or ranking, and there will be no
+        -1 in expert_ids.
 
     Returns:
     - sorted_token_ids: A tensor containing the sorted token indices according
@@ -67,6 +74,10 @@ def moe_align_block_size(
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
     if pad_sorted_ids:
         max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
+    if topk_ids.numel() < num_experts:
+        max_num_tokens_padded = min(
+            topk_ids.numel() * block_size, max_num_tokens_padded
+        )
     sorted_ids = torch.empty(
         (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
     )
@@ -77,9 +88,16 @@ def moe_align_block_size(
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
 
     ops.moe_align_block_size(
-        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
+        topk_ids,
+        num_experts,
+        block_size,
+        sorted_ids,
+        expert_ids,
+        num_tokens_post_pad,
+        expert_map if ignore_invalid_experts else None,
     )
-    if expert_map is not None:
+
+    if expert_map is not None and not ignore_invalid_experts:
         expert_ids = expert_map[expert_ids]
 
     return sorted_ids, expert_ids, num_tokens_post_pad
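Illustrative note (not part of the patch): the two modes described in the docstring can be compared directly through the public wrapper. A hedged sketch, assuming a CUDA device and the import path used by recent vLLM trees:

import torch
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
    moe_align_block_size,
)

num_experts, block_size = 8, 16
# Half of the experts are local to this rank; the rest map to -1.
expert_map = torch.full((num_experts,), -1, dtype=torch.int32, device="cuda")
expert_map[:4] = torch.arange(4, dtype=torch.int32, device="cuda")
topk_ids = torch.randint(0, num_experts, (32, 2), dtype=torch.int32, device="cuda")

# Pre-existing behaviour: count every expert, then remap blocks, leaving -1 for remote experts.
_, expert_ids_marked, _ = moe_align_block_size(
    topk_ids, block_size, num_experts, expert_map, ignore_invalid_experts=False
)
# New behaviour: remote tokens never enter counting/ranking, so no -1 appears.
_, expert_ids_local, _ = moe_align_block_size(
    topk_ids, block_size, num_experts, expert_map, ignore_invalid_experts=True
)
print((expert_ids_marked == -1).any().item())   # often True with this map
print((expert_ids_local == -1).any().item())    # False by construction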