Fix CUDA permute/unpermute for use with DeepGemm Moe (#17934)

Signed-off-by: Caleb_Du <Caleb_Du@zju.edu.cn>
This commit is contained in:
Caleb_Du 2025-07-27 22:08:00 +08:00 committed by GitHub
parent bda9d0535f
commit 57c22e57f9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 238 additions and 211 deletions

View File

@ -8,12 +8,13 @@ import ray
import torch import torch
from transformers import AutoConfig from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
_moe_permute, _moe_permute,
_moe_unpermute_and_reduce, _moe_unpermute_and_reduce,
moe_permute,
moe_unpermute,
) )
from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import *
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
@ -63,18 +64,19 @@ def benchmark_permute(
def run(): def run():
if use_customized_permute: if use_customized_permute:
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = ( (
moe_permute( permuted_hidden_states,
qhidden_states, a1q_scale,
topk_weights=topk_weights, first_token_off,
topk_ids=topk_ids, inv_perm_idx,
token_expert_indices=token_expert_indices, m_indices,
topk=topk, ) = moe_permute(
n_expert=num_experts, qhidden_states,
n_local_expert=num_experts, a1q_scale=None,
expert_map=None, topk_ids=topk_ids,
align_block_size=align_block_size, n_expert=num_experts,
) expert_map=None,
align_block_size=align_block_size,
) )
else: else:
( (
@ -150,18 +152,19 @@ def benchmark_unpermute(
def prepare(): def prepare():
if use_customized_permute: if use_customized_permute:
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = ( (
moe_permute( permuted_hidden_states,
qhidden_states, a1q_scale,
topk_weights=topk_weights, first_token_off,
topk_ids=topk_ids, inv_perm_idx,
token_expert_indices=token_expert_indices, m_indices,
topk=topk, ) = moe_permute(
n_expert=num_experts, qhidden_states,
n_local_expert=num_experts, a1q_scale=None,
expert_map=None, topk_ids=topk_ids,
align_block_size=align_block_size, n_expert=num_experts,
) expert_map=None,
align_block_size=align_block_size,
) )
# convert to fp16/bf16 as gemm output # convert to fp16/bf16 as gemm output
return ( return (
@ -191,16 +194,19 @@ def benchmark_unpermute(
def run(input: tuple): def run(input: tuple):
if use_customized_permute: if use_customized_permute:
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = input (
permuted_hidden_states,
first_token_off,
inv_perm_idx,
m_indices,
) = input
output = torch.empty_like(hidden_states)
moe_unpermute( moe_unpermute(
output,
permuted_hidden_states, permuted_hidden_states,
topk_weights, topk_weights,
topk_ids,
inv_perm_idx, inv_perm_idx,
first_token_off, first_token_off,
topk,
num_experts,
num_experts,
) )
else: else:
( (
@ -211,7 +217,11 @@ def benchmark_unpermute(
inv_perm, inv_perm,
) = input ) = input
_moe_unpermute_and_reduce( _moe_unpermute_and_reduce(
output_hidden_states, permuted_hidden_states, inv_perm, topk_weights output_hidden_states,
permuted_hidden_states,
inv_perm,
topk_weights,
True,
) )
# JIT compilation & warmup # JIT compilation & warmup

View File

@ -10,32 +10,28 @@
void moe_permute( void moe_permute(
const torch::Tensor& input, // [n_token, hidden] const torch::Tensor& input, // [n_token, hidden]
const torch::Tensor& topk_weights, //[n_token, topk] const torch::Tensor& topk_ids, // [n_token, topk]
torch::Tensor& topk_ids, // [n_token, topk]
const torch::Tensor& token_expert_indices, // [n_token, topk] const torch::Tensor& token_expert_indices, // [n_token, topk]
const std::optional<torch::Tensor>& expert_map, // [n_expert] const std::optional<torch::Tensor>& expert_map, // [n_expert]
int64_t n_expert, int64_t n_local_expert, int64_t topk, int64_t n_expert, int64_t n_local_expert, int64_t topk,
const std::optional<int64_t>& align_block_size, const std::optional<int64_t>& align_block_size,
torch::Tensor& torch::Tensor& permuted_input, // [permuted_size, hidden]
permuted_input, // [topk * n_token/align_block_size_m, hidden]
torch::Tensor& expert_first_token_offset, // [n_local_expert + 1] torch::Tensor& expert_first_token_offset, // [n_local_expert + 1]
torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk] torch::Tensor& inv_permuted_idx, // [n_token, topk]
torch::Tensor& permuted_idx, // [permute_size]
torch::Tensor& m_indices) { // [align_expand_m] torch::Tensor& m_indices) { // [align_expand_m]
TORCH_CHECK(topk_weights.scalar_type() == at::ScalarType::Float,
"topk_weights must be float32");
TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long, TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long,
"expert_first_token_offset must be int64"); "expert_first_token_offset must be int64");
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int, TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
"topk_ids must be int32"); "topk_ids must be int32");
TORCH_CHECK(token_expert_indices.scalar_type() == at::ScalarType::Int, TORCH_CHECK(token_expert_indices.scalar_type() == at::ScalarType::Int,
"token_expert_indices must be int32"); "token_expert_indices must be int32");
TORCH_CHECK(src_row_id2dst_row_id_map.scalar_type() == at::ScalarType::Int, TORCH_CHECK(inv_permuted_idx.scalar_type() == at::ScalarType::Int,
"src_row_id2dst_row_id_map must be int32"); "inv_permuted_idx must be int32");
TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1, TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1,
"expert_first_token_offset shape != n_local_expert+1") "expert_first_token_offset shape != n_local_expert+1")
TORCH_CHECK( TORCH_CHECK(inv_permuted_idx.sizes() == token_expert_indices.sizes(),
src_row_id2dst_row_id_map.sizes() == token_expert_indices.sizes(), "token_expert_indices shape must be same as inv_permuted_idx");
"token_expert_indices shape must be same as src_row_id2dst_row_id_map");
auto n_token = input.sizes()[0]; auto n_token = input.sizes()[0];
auto n_hidden = input.sizes()[1]; auto n_hidden = input.sizes()[1];
auto align_block_size_value = auto align_block_size_value =
@ -46,8 +42,9 @@ void moe_permute(
auto sort_workspace = torch::empty( auto sort_workspace = torch::empty(
{sorter_size}, {sorter_size},
torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false)); torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
auto copy_topk_ids = topk_ids.clone(); // copy topk_ids for preprocess
auto permuted_experts_id = torch::empty_like(topk_ids); auto permuted_experts_id = torch::empty_like(topk_ids);
auto dst_row_id2src_row_id_map = torch::empty_like(src_row_id2dst_row_id_map); auto sorted_row_idx = torch::empty_like(inv_permuted_idx);
auto align_expert_first_token_offset = auto align_expert_first_token_offset =
torch::zeros_like(expert_first_token_offset); torch::zeros_like(expert_first_token_offset);
@ -67,24 +64,22 @@ void moe_permute(
const int* expert_map_ptr = get_ptr<int>(expert_map.value()); const int* expert_map_ptr = get_ptr<int>(expert_map.value());
valid_num_ptr = valid_num_ptr =
get_ptr<int64_t>(expert_first_token_offset) + n_local_expert; get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
preprocessTopkIdLauncher(get_ptr<int>(topk_ids), n_token * topk, preprocessTopkIdLauncher(get_ptr<int>(copy_topk_ids), n_token * topk,
expert_map_ptr, n_expert, stream); expert_map_ptr, n_expert, stream);
} }
// expert sort topk expert id and scan expert id get expert_first_token_offset // expert sort topk expert id and scan expert id get expert_first_token_offset
sortAndScanExpert(get_ptr<int>(topk_ids), get_ptr<int>(token_expert_indices), sortAndScanExpert(
get_ptr<int>(permuted_experts_id), get_ptr<int>(copy_topk_ids), get_ptr<int>(token_expert_indices),
get_ptr<int>(dst_row_id2src_row_id_map), get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
get_ptr<int64_t>(expert_first_token_offset), n_token, get_ptr<int64_t>(expert_first_token_offset), n_token, n_expert,
n_expert, n_local_expert, topk, sorter, n_local_expert, topk, sorter, get_ptr<int>(sort_workspace), stream);
get_ptr<int>(sort_workspace), stream);
// dispatch expandInputRowsKernelLauncher // dispatch expandInputRowsKernelLauncher
MOE_DISPATCH(input.scalar_type(), [&] { MOE_DISPATCH(input.scalar_type(), [&] {
expandInputRowsKernelLauncher<scalar_t>( expandInputRowsKernelLauncher<scalar_t>(
get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input), get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input),
get_ptr<float>(topk_weights), get_ptr<int>(permuted_experts_id), get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
get_ptr<int>(dst_row_id2src_row_id_map), get_ptr<int>(inv_permuted_idx), get_ptr<int>(permuted_idx),
get_ptr<int>(src_row_id2dst_row_id_map),
get_ptr<int64_t>(expert_first_token_offset), n_token, valid_num_ptr, get_ptr<int64_t>(expert_first_token_offset), n_token, valid_num_ptr,
n_hidden, topk, n_local_expert, align_block_size_value, stream); n_hidden, topk, n_local_expert, align_block_size_value, stream);
}); });
@ -101,32 +96,34 @@ void moe_permute(
} }
void moe_unpermute( void moe_unpermute(
const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden] const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden]
const torch::Tensor& topk_weights, //[n_token, topk] const torch::Tensor& topk_weights, // [n_token, topk]
const torch::Tensor& topk_ids, // [n_token, topk] const torch::Tensor& inv_permuted_idx, // [n_token, topk]
const torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk] const std::optional<torch::Tensor>&
const torch::Tensor& expert_first_token_offset, // [n_local_expert+1] expert_first_token_offset, // [n_local_expert+1]
int64_t n_expert, int64_t n_local_expert, int64_t topk, int64_t topk,
torch::Tensor& hidden_states // [n_token, hidden] torch::Tensor& hidden_states // [n_token, hidden]
) { ) {
TORCH_CHECK(src_row_id2dst_row_id_map.sizes() == topk_ids.sizes(),
"topk_ids shape must be same as src_row_id2dst_row_id_map");
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
"topk_ids must be int32");
TORCH_CHECK( TORCH_CHECK(
permuted_hidden_states.scalar_type() == hidden_states.scalar_type(), permuted_hidden_states.scalar_type() == hidden_states.scalar_type(),
"topk_ids dtype must be same as src_row_id2dst_row_id_map"); "permuted_hidden_states dtype must be same as hidden_states");
auto n_token = hidden_states.size(0); auto n_token = hidden_states.size(0);
auto n_hidden = hidden_states.size(1); auto n_hidden = hidden_states.size(1);
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
const int64_t* valid_ptr =
get_ptr<int64_t>(expert_first_token_offset) + n_local_expert; int64_t const* valid_ptr = nullptr;
if (expert_first_token_offset.has_value()) {
int n_local_expert = expert_first_token_offset.value().size(0) - 1;
valid_ptr =
get_ptr<int64_t>(expert_first_token_offset.value()) + n_local_expert;
}
MOE_DISPATCH(hidden_states.scalar_type(), [&] { MOE_DISPATCH(hidden_states.scalar_type(), [&] {
finalizeMoeRoutingKernelLauncher<scalar_t, scalar_t>( finalizeMoeRoutingKernelLauncher<scalar_t, scalar_t>(
get_ptr<scalar_t>(permuted_hidden_states), get_ptr<scalar_t>(permuted_hidden_states),
get_ptr<scalar_t>(hidden_states), get_ptr<float>(topk_weights), get_ptr<scalar_t>(hidden_states), get_ptr<float>(topk_weights),
get_ptr<int>(src_row_id2dst_row_id_map), get_ptr<int>(topk_ids), get_ptr<int>(inv_permuted_idx), n_token, n_hidden, topk, valid_ptr,
n_token, n_hidden, topk, valid_ptr, stream); stream);
}); });
} }

View File

@ -177,7 +177,7 @@ __global__ void getMIndicesKernel(int64_t* expert_first_token_offset,
int tidx = threadIdx.x; int tidx = threadIdx.x;
extern __shared__ int64_t smem_expert_first_token_offset[]; extern __shared__ int64_t smem_expert_first_token_offset[];
for (int i = tidx; i <= num_local_expert; i += blockDim.x) { for (int i = tidx; i <= num_local_expert; i += blockDim.x) {
smem_expert_first_token_offset[tidx] = __ldg(expert_first_token_offset + i); smem_expert_first_token_offset[i] = __ldg(expert_first_token_offset + i);
} }
__syncthreads(); __syncthreads();
auto last_token_offset = smem_expert_first_token_offset[eidx + 1]; auto last_token_offset = smem_expert_first_token_offset[eidx + 1];

View File

@ -57,31 +57,19 @@ void sortAndScanExpert(int* expert_for_source_row, const int* source_rows,
template <typename T> template <typename T>
void expandInputRowsKernelLauncher( void expandInputRowsKernelLauncher(
T const* unpermuted_input, T* permuted_output, T const* unpermuted_input, T* permuted_output, int* sorted_experts,
const float* unpermuted_scales, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row, int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row, int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int64_t* expert_first_token_offset, int64_t const num_rows, int64_t* expert_first_token_offset, int64_t const num_rows,
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k, int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
int num_local_experts, const int& align_block_size, cudaStream_t stream); int num_local_experts, const int& align_block_size, cudaStream_t stream);
// Final kernel to unpermute and scale
// This kernel unpermutes the original data, does the k-way reduction and
// performs the final skip connection.
template <typename T, typename OutputType, bool CHECK_SKIPPED>
__global__ void finalizeMoeRoutingKernel(
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
int const* expert_for_source_row, int64_t const orig_cols, int64_t const k,
int64_t const* num_valid_ptr);
template <class T, class OutputType> template <class T, class OutputType>
void finalizeMoeRoutingKernelLauncher( void finalizeMoeRoutingKernelLauncher(
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
float const* scales, int const* expanded_source_row_to_expanded_dest_row, float const* scales, int const* expanded_source_row_to_expanded_dest_row,
int const* expert_for_source_row, int64_t const num_rows, int64_t const num_rows, int64_t const cols, int64_t const k,
int64_t const cols, int64_t const k, int64_t const* num_valid_ptr, int64_t const* num_valid_ptr, cudaStream_t stream);
cudaStream_t stream);
void preprocessTopkIdLauncher(int* topk_id_ptr, int size, void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
const int* expert_map_ptr, int num_experts, const int* expert_map_ptr, int num_experts,

View File

@ -2,10 +2,9 @@
template <typename T, bool CHECK_SKIPPED, bool ALIGN_BLOCK_SIZE> template <typename T, bool CHECK_SKIPPED, bool ALIGN_BLOCK_SIZE>
__global__ void expandInputRowsKernel( __global__ void expandInputRowsKernel(
T const* unpermuted_input, T* permuted_output, T const* unpermuted_input, T* permuted_output, int* sorted_experts,
const float* unpermuted_scales, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row, int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row, int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int64_t* expert_first_token_offset, int64_t const num_rows, int64_t* expert_first_token_offset, int64_t const num_rows,
int64_t const* num_dest_rows, int64_t const cols, int64_t k, int64_t const* num_dest_rows, int64_t const cols, int64_t k,
int num_local_experts, int align_block_size) { int num_local_experts, int align_block_size) {
@ -54,6 +53,10 @@ __global__ void expandInputRowsKernel(
assert(expanded_dest_row <= INT32_MAX); assert(expanded_dest_row <= INT32_MAX);
expanded_source_row_to_expanded_dest_row[expanded_source_row] = expanded_source_row_to_expanded_dest_row[expanded_source_row] =
static_cast<int>(expanded_dest_row); static_cast<int>(expanded_dest_row);
// skip non local expert token
if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) {
permuted_idx[expanded_dest_row] = expanded_source_row;
}
} }
if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) { if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) {
@ -62,7 +65,7 @@ __global__ void expandInputRowsKernel(
using DataElem = cutlass::Array<T, ELEM_PER_THREAD>; using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;
// Duplicate and permute rows // Duplicate and permute rows
int64_t const source_row = expanded_source_row % num_rows; int64_t const source_row = expanded_source_row / k;
auto const* source_row_ptr = auto const* source_row_ptr =
reinterpret_cast<DataElem const*>(unpermuted_input + source_row * cols); reinterpret_cast<DataElem const*>(unpermuted_input + source_row * cols);
@ -82,10 +85,9 @@ __global__ void expandInputRowsKernel(
template <typename T> template <typename T>
void expandInputRowsKernelLauncher( void expandInputRowsKernelLauncher(
T const* unpermuted_input, T* permuted_output, T const* unpermuted_input, T* permuted_output, int* sorted_experts,
const float* unpermuted_scales, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row, int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row, int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int64_t* expert_first_token_offset, int64_t const num_rows, int64_t* expert_first_token_offset, int64_t const num_rows,
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k, int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
int num_local_experts, const int& align_block_size, cudaStream_t stream) { int num_local_experts, const int& align_block_size, cudaStream_t stream) {
@ -105,11 +107,11 @@ void expandInputRowsKernelLauncher(
int64_t smem_size = sizeof(int64_t) * (num_local_experts + 1); int64_t smem_size = sizeof(int64_t) * (num_local_experts + 1);
func<<<blocks, threads, smem_size, stream>>>( func<<<blocks, threads, smem_size, stream>>>(
unpermuted_input, permuted_output, unpermuted_scales, sorted_experts, unpermuted_input, permuted_output, sorted_experts,
expanded_dest_row_to_expanded_source_row, expanded_dest_row_to_expanded_source_row,
expanded_source_row_to_expanded_dest_row, expert_first_token_offset, expanded_source_row_to_expanded_dest_row, permuted_idx,
num_rows, num_valid_tokens_ptr, cols, k, num_local_experts, expert_first_token_offset, num_rows, num_valid_tokens_ptr, cols, k,
align_block_size); num_local_experts, align_block_size);
} }
template <class T, class U> template <class T, class U>
@ -128,11 +130,9 @@ template <typename T, typename OutputType, bool CHECK_SKIPPED>
__global__ void finalizeMoeRoutingKernel( __global__ void finalizeMoeRoutingKernel(
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
float const* scales, int const* expanded_source_row_to_expanded_dest_row, float const* scales, int const* expanded_source_row_to_expanded_dest_row,
int const* expert_for_source_row, int64_t const orig_cols, int64_t const k, int64_t const orig_cols, int64_t const k, int64_t const* num_valid_ptr) {
int64_t const* num_valid_ptr) {
assert(orig_cols % 4 == 0); assert(orig_cols % 4 == 0);
int64_t const original_row = blockIdx.x; int64_t const original_row = blockIdx.x;
int64_t const num_rows = gridDim.x;
auto const offset = original_row * orig_cols; auto const offset = original_row * orig_cols;
OutputType* reduced_row_ptr = reduced_unpermuted_output + offset; OutputType* reduced_row_ptr = reduced_unpermuted_output + offset;
int64_t const num_valid = *num_valid_ptr; int64_t const num_valid = *num_valid_ptr;
@ -159,14 +159,13 @@ __global__ void finalizeMoeRoutingKernel(
ComputeElem thread_output; ComputeElem thread_output;
thread_output.fill(0); thread_output.fill(0);
for (int k_idx = 0; k_idx < k; ++k_idx) { for (int k_idx = 0; k_idx < k; ++k_idx) {
int64_t const expanded_original_row = original_row + k_idx * num_rows; int64_t const expanded_original_row = original_row * k + k_idx;
int64_t const expanded_permuted_row = int64_t const expanded_permuted_row =
expanded_source_row_to_expanded_dest_row[expanded_original_row]; expanded_source_row_to_expanded_dest_row[expanded_original_row];
int64_t const k_offset = original_row * k + k_idx; int64_t const k_offset = original_row * k + k_idx;
float const row_scale = scales[k_offset]; float const row_scale = scales[k_offset];
// Check after row_rescale has accumulated
if (CHECK_SKIPPED && expanded_permuted_row >= num_valid) { if (CHECK_SKIPPED && expanded_permuted_row >= num_valid) {
continue; continue;
} }
@ -189,9 +188,8 @@ template <class T, class OutputType>
void finalizeMoeRoutingKernelLauncher( void finalizeMoeRoutingKernelLauncher(
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
float const* scales, int const* expanded_source_row_to_expanded_dest_row, float const* scales, int const* expanded_source_row_to_expanded_dest_row,
int const* expert_for_source_row, int64_t const num_rows, int64_t const num_rows, int64_t const cols, int64_t const k,
int64_t const cols, int64_t const k, int64_t const* num_valid_ptr, int64_t const* num_valid_ptr, cudaStream_t stream) {
cudaStream_t stream) {
int64_t const blocks = num_rows; int64_t const blocks = num_rows;
int64_t const threads = 256; int64_t const threads = 256;
bool const check_finished = num_valid_ptr != nullptr; bool const check_finished = num_valid_ptr != nullptr;
@ -201,6 +199,5 @@ void finalizeMoeRoutingKernelLauncher(
auto* const kernel = func_map[check_finished]; auto* const kernel = func_map[check_finished];
kernel<<<blocks, threads, 0, stream>>>( kernel<<<blocks, threads, 0, stream>>>(
expanded_permuted_rows, reduced_unpermuted_output, scales, expanded_permuted_rows, reduced_unpermuted_output, scales,
expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k, expanded_source_row_to_expanded_dest_row, cols, k, num_valid_ptr);
num_valid_ptr);
} }

View File

@ -56,18 +56,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" -> Tensor"); " -> Tensor");
m.def( m.def(
"moe_permute(Tensor input, Tensor topk_weight, Tensor! topk_ids," "moe_permute(Tensor input, Tensor topk_ids,"
"Tensor token_expert_indices, Tensor? expert_map, int n_expert," "Tensor token_expert_indices, Tensor? expert_map, int n_expert,"
"int n_local_expert," "int n_local_expert,"
"int topk, int? align_block_size,Tensor! permuted_input, Tensor! " "int topk, int? align_block_size,Tensor! permuted_input, Tensor! "
"expert_first_token_offset, Tensor! src_row_id2dst_row_id_map, Tensor! " "expert_first_token_offset, Tensor! inv_permuted_idx, Tensor! "
"m_indices)->()"); "permuted_idx, Tensor! m_indices)->()");
m.def( m.def(
"moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights," "moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights,"
"Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor " "Tensor inv_permuted_idx, Tensor? expert_first_token_offset, "
"expert_first_token_offset, int n_expert, int n_local_expert,int " "int topk, Tensor! hidden_states)->()");
"topk, Tensor! hidden_states)->()");
m.def("moe_permute_unpermute_supported() -> bool"); m.def("moe_permute_unpermute_supported() -> bool");
m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported); m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported);

View File

@ -17,28 +17,34 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
moe_permute, moe_permute_unpermute_supported, moe_unpermute) moe_permute, moe_permute_unpermute_supported, moe_unpermute)
from vllm.platforms import current_platform from vllm.platforms import current_platform
NUM_EXPERTS = [16, 64] NUM_EXPERTS = [16, 64, 256]
TOP_KS = [2, 4, 6, 8] TOP_KS = [2, 4, 6, 8]
EP_SIZE = [1, 4, 16] EP_SIZE = [1, 4, 16]
current_platform.seed_everything(0) current_platform.seed_everything(0)
def torch_permute(hidden_states: torch.Tensor, def torch_permute(
topk_ids: torch.Tensor, hidden_states: torch.Tensor,
token_expert_indices: torch.Tensor, topk_ids: torch.Tensor,
topk: int, # token_expert_indices: torch.Tensor,
n_expert: int, topk: int,
n_local_expert: int, n_expert: int,
start_expert: int, n_local_expert: int,
expert_map: Optional[torch.Tensor] = None, start_expert: int,
align_block_size: Optional[int] = None, expert_map: Optional[torch.Tensor] = None,
fill_invalid_expert: int = -1) -> list[torch.Tensor]: align_block_size: Optional[int] = None,
fill_invalid_expert: int = -1) -> list[torch.Tensor]:
n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1] n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1]
if expert_map is not None: if expert_map is not None:
is_local_expert = (expert_map[topk_ids] != -1) is_local_expert = (expert_map[topk_ids] != -1)
not_local_expert = (expert_map[topk_ids] == -1) not_local_expert = (expert_map[topk_ids] == -1)
topk_ids = is_local_expert * ( topk_ids = is_local_expert * (
topk_ids - start_expert) + not_local_expert * (topk_ids + n_expert) topk_ids - start_expert) + not_local_expert * (topk_ids + n_expert)
token_expert_indices = torch.arange(0,
n_token * topk,
dtype=torch.int32,
device=hidden_states.device).reshape(
(n_token, topk))
sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(), sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(),
stable=True) stable=True)
@ -59,8 +65,8 @@ def torch_permute(hidden_states: torch.Tensor,
valid_row_idx = [] valid_row_idx = []
if align_block_size is None: if align_block_size is None:
permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map % permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map //
n_token, ...] topk, ...]
permuted_row_size = permuted_hidden_states.shape[0] permuted_row_size = permuted_hidden_states.shape[0]
m_indices = torch.empty(permuted_row_size, m_indices = torch.empty(permuted_row_size,
device="cuda", device="cuda",
@ -73,14 +79,21 @@ def torch_permute(hidden_states: torch.Tensor,
0, n_token * topk, device="cuda", 0, n_token * topk, device="cuda",
dtype=torch.int32)[src2dst_idx].reshape((n_token, topk)) dtype=torch.int32)[src2dst_idx].reshape((n_token, topk))
valid_row_idx += [i for i in range(expert_first_token_offset[-1])] valid_row_idx += [i for i in range(expert_first_token_offset[-1])]
dst_row_id2src_row_id_map[
expert_first_token_offset[-1]:] = n_token * topk
return [ return [
permuted_hidden_states, expert_first_token_offset, permuted_hidden_states, expert_first_token_offset,
src_row_id2dst_row_id_map, m_indices, valid_row_idx src_row_id2dst_row_id_map, dst_row_id2src_row_id_map, m_indices,
valid_row_idx
] ]
else: else:
permuted_row_size = (topk * n_token + n_expert * permuted_row_size = (topk * n_token + n_expert *
(align_block_size - 1) + align_block_size - (align_block_size - 1) + align_block_size -
1) // align_block_size * align_block_size 1) // align_block_size * align_block_size
permuted_idx = torch.full((permuted_row_size, ),
n_token * topk,
dtype=torch.int32,
device=hidden_states.device)
permuted_hidden_states = torch.empty((permuted_row_size, n_hidden), permuted_hidden_states = torch.empty((permuted_row_size, n_hidden),
device="cuda", device="cuda",
dtype=hidden_states.dtype) dtype=hidden_states.dtype)
@ -105,13 +118,16 @@ def torch_permute(hidden_states: torch.Tensor,
align_first_token_offset = align_expert_first_token_offset[i - 1] align_first_token_offset = align_expert_first_token_offset[i - 1]
align_last_token_offset = align_expert_first_token_offset[i] align_last_token_offset = align_expert_first_token_offset[i]
dst_row_id2src_row_id_in_expert = dst_row_id2src_row_id_map[ dst_row_id2src_row_id_in_expert = dst_row_id2src_row_id_map[
first_token_offset:first_token_offset + first_token_offset:first_token_offset + n_token_in_expert]
n_token_in_expert] % n_token
# store token in current expert with align_first_token_offset # store token in current expert with align_first_token_offset
permuted_hidden_states[align_first_token_offset:\ permuted_hidden_states[align_first_token_offset:\
align_first_token_offset+n_token_in_expert,\ align_first_token_offset+n_token_in_expert,\
...] = hidden_states[\ ...] = hidden_states[\
dst_row_id2src_row_id_in_expert, ...] dst_row_id2src_row_id_in_expert // topk,\
...]
permuted_idx[align_first_token_offset:\
align_first_token_offset+\
n_token_in_expert] = dst_row_id2src_row_id_in_expert
# set current expert m_indices # set current expert m_indices
m_indices[align_first_token_offset:align_last_token_offset] = i - 1 m_indices[align_first_token_offset:align_last_token_offset] = i - 1
valid_row_idx += [ valid_row_idx += [
@ -135,7 +151,7 @@ def torch_permute(hidden_states: torch.Tensor,
src2dst_idx].reshape((n_token, topk)) src2dst_idx].reshape((n_token, topk))
return [ return [
permuted_hidden_states, align_expert_first_token_offset, permuted_hidden_states, align_expert_first_token_offset,
align_src_row_id2dst_row_id, m_indices, valid_row_idx align_src_row_id2dst_row_id, permuted_idx, m_indices, valid_row_idx
] ]
@ -146,15 +162,18 @@ def torch_unpermute(permuted_hidden_states: torch.Tensor,
valid_row_idx: torch.Tensor, topk: int, valid_row_idx: torch.Tensor, topk: int,
n_expert: int) -> torch.Tensor: n_expert: int) -> torch.Tensor:
# ignore invalid row # ignore invalid row
n_hidden = permuted_hidden_states.shape[1]
mask = torch.zeros(permuted_hidden_states.shape[0], mask = torch.zeros(permuted_hidden_states.shape[0],
dtype=bool, dtype=bool,
device="cuda") device="cuda")
mask[valid_row_idx] = True mask[valid_row_idx] = True
permuted_hidden_states[~mask] = 0 permuted_hidden_states[~mask] = 0
idx = src_row_id2dst_row_id_map.flatten()[
token_expert_indices.flatten()].reshape(token_expert_indices.shape) permuted_hidden_states = permuted_hidden_states[
output = permuted_hidden_states[idx, ...] * topk_weights[..., None] src_row_id2dst_row_id_map.flatten(), ...]
output = output.sum(dim=1).to(permuted_hidden_states.dtype) permuted_hidden_states = permuted_hidden_states.view(-1, topk, n_hidden)
output = (permuted_hidden_states * topk_weights.unsqueeze(2)).sum(1).to(
permuted_hidden_states.dtype)
return output return output
@ -184,43 +203,56 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype) gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype)
topk_weights, topk_ids, token_expert_indices = fused_topk( topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states, gating_output, topk, False) hidden_states, gating_output, topk, False)
gold0, gold1, gold2, gold3, valid_row_idx = torch_permute( (gold_permuted_hidden_states, gold_expert_first_token_offset,
hidden_states, gold_inv_permuted_idx, gold_permuted_idx, gold_m_indices,
topk_ids, valid_row_idx) = torch_permute(
token_expert_indices, hidden_states,
topk, topk_ids,
n_expert, # token_expert_indices,
n_local_expert, topk,
start_expert, n_expert,
expert_map=expert_map, n_local_expert,
align_block_size=align_block_size, start_expert,
fill_invalid_expert=fill_invalid_expert) expert_map=expert_map,
align_block_size=align_block_size,
fill_invalid_expert=fill_invalid_expert)
result0, result1, result2, result3 = moe_permute( (permuted_hidden_states, _, expert_first_token_offset, inv_permuted_idx,
hidden_states, topk_weights, topk_ids, token_expert_indices, topk, m_indices) = moe_permute(hidden_states=hidden_states,
n_expert, n_local_expert, expert_map, align_block_size, a1q_scale=None,
fill_invalid_expert) topk_ids=topk_ids,
n_expert=n_expert,
n_local_expert=n_local_expert,
expert_map=expert_map,
align_block_size=align_block_size,
fill_invalid_expert=fill_invalid_expert)
# check expert_first_token_offset # check expert_first_token_offset
torch.testing.assert_close(gold1, result1, atol=0, rtol=0) torch.testing.assert_close(gold_expert_first_token_offset,
# check src_row_id2dst_row_id_map expert_first_token_offset,
torch.testing.assert_close(gold2, result2, atol=0, rtol=0) atol=0,
# check mindice rtol=0)
torch.testing.assert_close(gold3, result3, atol=0, rtol=0) # check src_row_id2dst_row_id_map
# check permuted_hidden_states, only valid token torch.testing.assert_close(gold_inv_permuted_idx.flatten(),
torch.testing.assert_close(gold0[valid_row_idx], inv_permuted_idx,
result0[valid_row_idx], atol=0,
rtol=0)
# check mindice
torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0)
# check permuted_hidden_states, only valid token
torch.testing.assert_close(gold_permuted_hidden_states[valid_row_idx],
permuted_hidden_states[valid_row_idx],
atol=0, atol=0,
rtol=0) rtol=0)
# add a random tensor to simulate group gemm # add a random tensor to simulate group gemm
result0 = 0.5 * result0 + torch.randn_like(result0) result0 = 0.5 * permuted_hidden_states + torch.randn_like(
permuted_hidden_states)
result4 = torch.empty_like(hidden_states)
moe_unpermute(result4, result0, topk_weights, inv_permuted_idx,
expert_first_token_offset)
result4 = moe_unpermute(result0, topk_weights, topk_ids, result2, result1,
topk, n_expert, n_local_expert)
gold4 = torch_unpermute(result0, topk_weights, topk_ids, gold4 = torch_unpermute(result0, topk_weights, topk_ids,
token_expert_indices, result2, valid_row_idx, topk, token_expert_indices, inv_permuted_idx,
n_local_expert) valid_row_idx, topk, n_local_expert)
# check unpermuted hidden # check unpermuted hidden
torch.testing.assert_close(result4, gold4, atol=2e-2, rtol=0) torch.testing.assert_close(result4, gold4, atol=2e-2, rtol=0)

View File

@ -76,25 +76,22 @@ def _moe_unpermute_and_reduce(
def moe_permute( def moe_permute(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
topk_weights: torch.Tensor, a1q_scale: Optional[torch.Tensor],
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
token_expert_indices: torch.Tensor,
topk: int,
n_expert: int, n_expert: int,
n_local_expert: int, n_local_expert: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
align_block_size: Optional[int] = None, align_block_size: Optional[int] = None,
fill_invalid_expert: int = -1 fill_invalid_expert: int = -1
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
torch.Tensor]:
""" """
This function expands and permutes activation to gather uncontinuous tokens This function expands and permutes activation to gather uncontinuous tokens
for each expert. for each expert.
Parameters: Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer. - hidden_states (torch.Tensor): The input tensor to the MoE layer.
- topk_weights (torch.Tensor): topk expert route weight for each token. - a1q_scale (Optional[torch.Tensor]): quant scale for hidden_states
- topk_ids (torch.Tensor): topk expert route id for each token. - topk_ids (torch.Tensor): topk expert route id for each token.
- token_expert_indices (torch.Tensor): indice for expanded hidden.
- topk (int): The number of top-k experts to select.
- n_expert (int): The number of expert. - n_expert (int): The number of expert.
- n_local_expert (int): The number of expert in current EP rank. - n_local_expert (int): The number of expert in current EP rank.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
@ -105,14 +102,17 @@ def moe_permute(
to workaround DeepGemm unsupported -1 in m_indices to workaround DeepGemm unsupported -1 in m_indices
Returns: Returns:
- permuted_hidden_states (torch.Tensor): permuted activation. - permuted_hidden_states (torch.Tensor): permuted activation.
- a1q_scale (Optional[torch.Tensor]): quant scale for hidden_states
- expert_first_token_offset (torch.Tensor): offset of the first token - expert_first_token_offset (torch.Tensor): offset of the first token
of each expert for standard grouped gemm. if enable 'align_block_size' of each expert for standard grouped gemm. if enable 'align_block_size'
expert_first_token_offset will align up to 'align_block_size'. expert_first_token_offset will align up to 'align_block_size'.
- src_row_id2dst_row_id_map (torch.Tensor): idx map for moe_unpermute. - inv_permuted_idx (torch.Tensor): idx map for moe_unpermute.
- permuted_idx (torch.Tensor): idx map from hidden to permuted_hidden.
- m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records - m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records
the group which the j-th row of the LHS belong to.` the group which the j-th row of the LHS belong to.`
""" """
n_token, n_hidden = hidden_states.size() n_token, n_hidden = hidden_states.size()
topk = topk_ids.size(1)
assert (n_hidden * hidden_states.element_size() assert (n_hidden * hidden_states.element_size()
) % 16 == 0, "permue kernel need hidden dim align to 16B" ) % 16 == 0, "permue kernel need hidden dim align to 16B"
permuted_row_size = n_token * topk permuted_row_size = n_token * topk
@ -120,12 +120,19 @@ def moe_permute(
permuted_row_size = (permuted_row_size + n_expert * permuted_row_size = (permuted_row_size + n_expert *
(align_block_size - 1) + align_block_size - (align_block_size - 1) + align_block_size -
1) // align_block_size * align_block_size 1) // align_block_size * align_block_size
if n_local_expert == -1:
n_local_expert = n_expert
permuted_hidden_states = torch.empty( permuted_hidden_states = torch.empty(
(permuted_row_size, n_hidden), (permuted_row_size, n_hidden),
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
device=hidden_states.device, device=hidden_states.device,
) )
token_expert_indices = torch.arange(0,
n_token * topk,
dtype=torch.int32,
device=hidden_states.device).reshape(
(n_token, topk))
m_indices = torch.full((permuted_row_size, ), m_indices = torch.full((permuted_row_size, ),
fill_invalid_expert, fill_invalid_expert,
dtype=torch.int32, dtype=torch.int32,
@ -133,57 +140,54 @@ def moe_permute(
expert_first_token_offset = torch.empty(n_local_expert + 1, expert_first_token_offset = torch.empty(n_local_expert + 1,
dtype=torch.int64, dtype=torch.int64,
device=hidden_states.device) device=hidden_states.device)
src_row_id2dst_row_id_map = torch.empty((n_token, topk), permuted_idx = torch.full((permuted_row_size, ),
dtype=torch.int32, n_token * topk,
device=hidden_states.device) dtype=torch.int32,
torch.ops._moe_C.moe_permute(hidden_states, topk_weights, topk_ids, device=hidden_states.device)
token_expert_indices, expert_map, n_expert, inv_permuted_idx = torch.empty((n_token, topk),
n_local_expert, topk, align_block_size, dtype=torch.int32,
permuted_hidden_states, device=hidden_states.device)
expert_first_token_offset, topk_ids = topk_ids.to(torch.int32)
src_row_id2dst_row_id_map, m_indices) torch.ops._moe_C.moe_permute(hidden_states, topk_ids, token_expert_indices,
return (permuted_hidden_states, expert_first_token_offset, expert_map, n_expert, n_local_expert, topk,
src_row_id2dst_row_id_map, m_indices) align_block_size, permuted_hidden_states,
expert_first_token_offset, inv_permuted_idx,
permuted_idx, m_indices)
if a1q_scale is not None:
a1q_scale = a1q_scale[permuted_idx.clamp(max=n_token * topk - 1) //
topk]
return (permuted_hidden_states, a1q_scale, expert_first_token_offset,
inv_permuted_idx.flatten(), m_indices)
def moe_unpermute( def moe_unpermute(
out: torch.Tensor,
permuted_hidden_states: torch.Tensor, permuted_hidden_states: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, inv_permuted_idx: torch.Tensor,
src_row_id2dst_row_id_map: torch.Tensor, expert_first_token_offset: Optional[torch.Tensor] = None,
expert_first_token_offset: torch.Tensor, ) -> None:
topk: int,
n_expert: int,
n_local_expert: int,
) -> torch.Tensor:
""" """
This function expands and permutes activation to gathering uncontinuous This function expands and permutes activation to gathering uncontinuous
tokens for each expert. tokens for each expert.
Parameters: Parameters:
- out (torch.Tensor): output tensor
- permuted_hidden_states (torch.Tensor): permuted activation. - permuted_hidden_states (torch.Tensor): permuted activation.
- topk_weights (torch.Tensor): topk expert route weight for each token. - topk_weights (torch.Tensor): topk expert route weight for each token.
- topk_ids (torch.Tensor): topk expert route id for each token. - inv_permuted_idx (torch.Tensor): row idx map for moe_unpermute.
- expert_first_token_offset (torch.Tensor): offset of the first token - expert_first_token_offset (Optional[torch.Tensor]): offset of the first
of each expert for grouped gemm. token of each expert for grouped gemm.
- topk (int): The number of top-k experts to select.
- n_expert (int): The number of expert.
- n_local_expert (int): The number of expert in current EP rank.
Returns: Returns:
- hidden_states (torch.Tensor): The reduced and unpermuted activation - hidden_states (torch.Tensor): The reduced and unpermuted activation
tensor. tensor.
""" """
n_token, n_hidden = topk_weights.size(0), permuted_hidden_states.size(-1) topk = topk_weights.size(1)
n_hidden = permuted_hidden_states.size(-1)
assert (n_hidden * permuted_hidden_states.element_size() assert (n_hidden * permuted_hidden_states.element_size()
) % 16 == 0, "unpermue kernel need hidden dim align to 16B" ) % 16 == 0, "unpermue kernel need hidden dim align to 16B"
hidden_states = torch.empty((n_token, n_hidden),
dtype=permuted_hidden_states.dtype,
device=permuted_hidden_states.device)
torch.ops._moe_C.moe_unpermute(permuted_hidden_states, topk_weights, torch.ops._moe_C.moe_unpermute(permuted_hidden_states, topk_weights,
topk_ids, src_row_id2dst_row_id_map, inv_permuted_idx, expert_first_token_offset,
expert_first_token_offset, n_expert, topk, out)
n_local_expert, topk, hidden_states)
return hidden_states
def moe_permute_unpermute_supported(): def moe_permute_unpermute_supported():