Fix CUDA permute/unpermute for use with DeepGemm Moe (#17934)

Signed-off-by: Caleb_Du <Caleb_Du@zju.edu.cn>
This commit is contained in:
Caleb_Du 2025-07-27 22:08:00 +08:00 committed by GitHub
parent bda9d0535f
commit 57c22e57f9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 238 additions and 211 deletions

View File

@ -8,12 +8,13 @@ import ray
import torch import torch
from transformers import AutoConfig from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
_moe_permute, _moe_permute,
_moe_unpermute_and_reduce, _moe_unpermute_and_reduce,
moe_permute,
moe_unpermute,
) )
from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import *
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
@ -63,18 +64,19 @@ def benchmark_permute(
def run(): def run():
if use_customized_permute: if use_customized_permute:
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = ( (
moe_permute( permuted_hidden_states,
qhidden_states, a1q_scale,
topk_weights=topk_weights, first_token_off,
topk_ids=topk_ids, inv_perm_idx,
token_expert_indices=token_expert_indices, m_indices,
topk=topk, ) = moe_permute(
n_expert=num_experts, qhidden_states,
n_local_expert=num_experts, a1q_scale=None,
expert_map=None, topk_ids=topk_ids,
align_block_size=align_block_size, n_expert=num_experts,
) expert_map=None,
align_block_size=align_block_size,
) )
else: else:
( (
@ -150,18 +152,19 @@ def benchmark_unpermute(
def prepare(): def prepare():
if use_customized_permute: if use_customized_permute:
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = ( (
moe_permute( permuted_hidden_states,
qhidden_states, a1q_scale,
topk_weights=topk_weights, first_token_off,
topk_ids=topk_ids, inv_perm_idx,
token_expert_indices=token_expert_indices, m_indices,
topk=topk, ) = moe_permute(
n_expert=num_experts, qhidden_states,
n_local_expert=num_experts, a1q_scale=None,
expert_map=None, topk_ids=topk_ids,
align_block_size=align_block_size, n_expert=num_experts,
) expert_map=None,
align_block_size=align_block_size,
) )
# convert to fp16/bf16 as gemm output # convert to fp16/bf16 as gemm output
return ( return (
@ -191,16 +194,19 @@ def benchmark_unpermute(
def run(input: tuple): def run(input: tuple):
if use_customized_permute: if use_customized_permute:
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = input (
permuted_hidden_states,
first_token_off,
inv_perm_idx,
m_indices,
) = input
output = torch.empty_like(hidden_states)
moe_unpermute( moe_unpermute(
output,
permuted_hidden_states, permuted_hidden_states,
topk_weights, topk_weights,
topk_ids,
inv_perm_idx, inv_perm_idx,
first_token_off, first_token_off,
topk,
num_experts,
num_experts,
) )
else: else:
( (
@ -211,7 +217,11 @@ def benchmark_unpermute(
inv_perm, inv_perm,
) = input ) = input
_moe_unpermute_and_reduce( _moe_unpermute_and_reduce(
output_hidden_states, permuted_hidden_states, inv_perm, topk_weights output_hidden_states,
permuted_hidden_states,
inv_perm,
topk_weights,
True,
) )
# JIT compilation & warmup # JIT compilation & warmup

View File

@ -10,32 +10,28 @@
void moe_permute( void moe_permute(
const torch::Tensor& input, // [n_token, hidden] const torch::Tensor& input, // [n_token, hidden]
const torch::Tensor& topk_weights, //[n_token, topk] const torch::Tensor& topk_ids, // [n_token, topk]
torch::Tensor& topk_ids, // [n_token, topk]
const torch::Tensor& token_expert_indices, // [n_token, topk] const torch::Tensor& token_expert_indices, // [n_token, topk]
const std::optional<torch::Tensor>& expert_map, // [n_expert] const std::optional<torch::Tensor>& expert_map, // [n_expert]
int64_t n_expert, int64_t n_local_expert, int64_t topk, int64_t n_expert, int64_t n_local_expert, int64_t topk,
const std::optional<int64_t>& align_block_size, const std::optional<int64_t>& align_block_size,
torch::Tensor& torch::Tensor& permuted_input, // [permuted_size, hidden]
permuted_input, // [topk * n_token/align_block_size_m, hidden]
torch::Tensor& expert_first_token_offset, // [n_local_expert + 1] torch::Tensor& expert_first_token_offset, // [n_local_expert + 1]
torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk] torch::Tensor& inv_permuted_idx, // [n_token, topk]
torch::Tensor& permuted_idx, // [permute_size]
torch::Tensor& m_indices) { // [align_expand_m] torch::Tensor& m_indices) { // [align_expand_m]
TORCH_CHECK(topk_weights.scalar_type() == at::ScalarType::Float,
"topk_weights must be float32");
TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long, TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long,
"expert_first_token_offset must be int64"); "expert_first_token_offset must be int64");
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int, TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
"topk_ids must be int32"); "topk_ids must be int32");
TORCH_CHECK(token_expert_indices.scalar_type() == at::ScalarType::Int, TORCH_CHECK(token_expert_indices.scalar_type() == at::ScalarType::Int,
"token_expert_indices must be int32"); "token_expert_indices must be int32");
TORCH_CHECK(src_row_id2dst_row_id_map.scalar_type() == at::ScalarType::Int, TORCH_CHECK(inv_permuted_idx.scalar_type() == at::ScalarType::Int,
"src_row_id2dst_row_id_map must be int32"); "inv_permuted_idx must be int32");
TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1, TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1,
"expert_first_token_offset shape != n_local_expert+1") "expert_first_token_offset shape != n_local_expert+1")
TORCH_CHECK( TORCH_CHECK(inv_permuted_idx.sizes() == token_expert_indices.sizes(),
src_row_id2dst_row_id_map.sizes() == token_expert_indices.sizes(), "token_expert_indices shape must be same as inv_permuted_idx");
"token_expert_indices shape must be same as src_row_id2dst_row_id_map");
auto n_token = input.sizes()[0]; auto n_token = input.sizes()[0];
auto n_hidden = input.sizes()[1]; auto n_hidden = input.sizes()[1];
auto align_block_size_value = auto align_block_size_value =
@ -46,8 +42,9 @@ void moe_permute(
auto sort_workspace = torch::empty( auto sort_workspace = torch::empty(
{sorter_size}, {sorter_size},
torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false)); torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
auto copy_topk_ids = topk_ids.clone(); // copy topk_ids for preprocess
auto permuted_experts_id = torch::empty_like(topk_ids); auto permuted_experts_id = torch::empty_like(topk_ids);
auto dst_row_id2src_row_id_map = torch::empty_like(src_row_id2dst_row_id_map); auto sorted_row_idx = torch::empty_like(inv_permuted_idx);
auto align_expert_first_token_offset = auto align_expert_first_token_offset =
torch::zeros_like(expert_first_token_offset); torch::zeros_like(expert_first_token_offset);
@ -67,24 +64,22 @@ void moe_permute(
const int* expert_map_ptr = get_ptr<int>(expert_map.value()); const int* expert_map_ptr = get_ptr<int>(expert_map.value());
valid_num_ptr = valid_num_ptr =
get_ptr<int64_t>(expert_first_token_offset) + n_local_expert; get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
preprocessTopkIdLauncher(get_ptr<int>(topk_ids), n_token * topk, preprocessTopkIdLauncher(get_ptr<int>(copy_topk_ids), n_token * topk,
expert_map_ptr, n_expert, stream); expert_map_ptr, n_expert, stream);
} }
// expert sort topk expert id and scan expert id get expert_first_token_offset // expert sort topk expert id and scan expert id get expert_first_token_offset
sortAndScanExpert(get_ptr<int>(topk_ids), get_ptr<int>(token_expert_indices), sortAndScanExpert(
get_ptr<int>(permuted_experts_id), get_ptr<int>(copy_topk_ids), get_ptr<int>(token_expert_indices),
get_ptr<int>(dst_row_id2src_row_id_map), get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
get_ptr<int64_t>(expert_first_token_offset), n_token, get_ptr<int64_t>(expert_first_token_offset), n_token, n_expert,
n_expert, n_local_expert, topk, sorter, n_local_expert, topk, sorter, get_ptr<int>(sort_workspace), stream);
get_ptr<int>(sort_workspace), stream);
// dispatch expandInputRowsKernelLauncher // dispatch expandInputRowsKernelLauncher
MOE_DISPATCH(input.scalar_type(), [&] { MOE_DISPATCH(input.scalar_type(), [&] {
expandInputRowsKernelLauncher<scalar_t>( expandInputRowsKernelLauncher<scalar_t>(
get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input), get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input),
get_ptr<float>(topk_weights), get_ptr<int>(permuted_experts_id), get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
get_ptr<int>(dst_row_id2src_row_id_map), get_ptr<int>(inv_permuted_idx), get_ptr<int>(permuted_idx),
get_ptr<int>(src_row_id2dst_row_id_map),
get_ptr<int64_t>(expert_first_token_offset), n_token, valid_num_ptr, get_ptr<int64_t>(expert_first_token_offset), n_token, valid_num_ptr,
n_hidden, topk, n_local_expert, align_block_size_value, stream); n_hidden, topk, n_local_expert, align_block_size_value, stream);
}); });
@ -101,32 +96,34 @@ void moe_permute(
} }
void moe_unpermute( void moe_unpermute(
const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden] const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden]
const torch::Tensor& topk_weights, //[n_token, topk] const torch::Tensor& topk_weights, // [n_token, topk]
const torch::Tensor& topk_ids, // [n_token, topk] const torch::Tensor& inv_permuted_idx, // [n_token, topk]
const torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk] const std::optional<torch::Tensor>&
const torch::Tensor& expert_first_token_offset, // [n_local_expert+1] expert_first_token_offset, // [n_local_expert+1]
int64_t n_expert, int64_t n_local_expert, int64_t topk, int64_t topk,
torch::Tensor& hidden_states // [n_token, hidden] torch::Tensor& hidden_states // [n_token, hidden]
) { ) {
TORCH_CHECK(src_row_id2dst_row_id_map.sizes() == topk_ids.sizes(),
"topk_ids shape must be same as src_row_id2dst_row_id_map");
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
"topk_ids must be int32");
TORCH_CHECK( TORCH_CHECK(
permuted_hidden_states.scalar_type() == hidden_states.scalar_type(), permuted_hidden_states.scalar_type() == hidden_states.scalar_type(),
"topk_ids dtype must be same as src_row_id2dst_row_id_map"); "permuted_hidden_states dtype must be same as hidden_states");
auto n_token = hidden_states.size(0); auto n_token = hidden_states.size(0);
auto n_hidden = hidden_states.size(1); auto n_hidden = hidden_states.size(1);
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
const int64_t* valid_ptr =
get_ptr<int64_t>(expert_first_token_offset) + n_local_expert; int64_t const* valid_ptr = nullptr;
if (expert_first_token_offset.has_value()) {
int n_local_expert = expert_first_token_offset.value().size(0) - 1;
valid_ptr =
get_ptr<int64_t>(expert_first_token_offset.value()) + n_local_expert;
}
MOE_DISPATCH(hidden_states.scalar_type(), [&] { MOE_DISPATCH(hidden_states.scalar_type(), [&] {
finalizeMoeRoutingKernelLauncher<scalar_t, scalar_t>( finalizeMoeRoutingKernelLauncher<scalar_t, scalar_t>(
get_ptr<scalar_t>(permuted_hidden_states), get_ptr<scalar_t>(permuted_hidden_states),
get_ptr<scalar_t>(hidden_states), get_ptr<float>(topk_weights), get_ptr<scalar_t>(hidden_states), get_ptr<float>(topk_weights),
get_ptr<int>(src_row_id2dst_row_id_map), get_ptr<int>(topk_ids), get_ptr<int>(inv_permuted_idx), n_token, n_hidden, topk, valid_ptr,
n_token, n_hidden, topk, valid_ptr, stream); stream);
}); });
} }

View File

@ -177,7 +177,7 @@ __global__ void getMIndicesKernel(int64_t* expert_first_token_offset,
int tidx = threadIdx.x; int tidx = threadIdx.x;
extern __shared__ int64_t smem_expert_first_token_offset[]; extern __shared__ int64_t smem_expert_first_token_offset[];
for (int i = tidx; i <= num_local_expert; i += blockDim.x) { for (int i = tidx; i <= num_local_expert; i += blockDim.x) {
smem_expert_first_token_offset[tidx] = __ldg(expert_first_token_offset + i); smem_expert_first_token_offset[i] = __ldg(expert_first_token_offset + i);
} }
__syncthreads(); __syncthreads();
auto last_token_offset = smem_expert_first_token_offset[eidx + 1]; auto last_token_offset = smem_expert_first_token_offset[eidx + 1];

View File

@ -57,31 +57,19 @@ void sortAndScanExpert(int* expert_for_source_row, const int* source_rows,
template <typename T> template <typename T>
void expandInputRowsKernelLauncher( void expandInputRowsKernelLauncher(
T const* unpermuted_input, T* permuted_output, T const* unpermuted_input, T* permuted_output, int* sorted_experts,
const float* unpermuted_scales, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row, int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row, int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int64_t* expert_first_token_offset, int64_t const num_rows, int64_t* expert_first_token_offset, int64_t const num_rows,
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k, int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
int num_local_experts, const int& align_block_size, cudaStream_t stream); int num_local_experts, const int& align_block_size, cudaStream_t stream);
// Final kernel to unpermute and scale
// This kernel unpermutes the original data, does the k-way reduction and
// performs the final skip connection.
template <typename T, typename OutputType, bool CHECK_SKIPPED>
__global__ void finalizeMoeRoutingKernel(
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
int const* expert_for_source_row, int64_t const orig_cols, int64_t const k,
int64_t const* num_valid_ptr);
template <class T, class OutputType> template <class T, class OutputType>
void finalizeMoeRoutingKernelLauncher( void finalizeMoeRoutingKernelLauncher(
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
float const* scales, int const* expanded_source_row_to_expanded_dest_row, float const* scales, int const* expanded_source_row_to_expanded_dest_row,
int const* expert_for_source_row, int64_t const num_rows, int64_t const num_rows, int64_t const cols, int64_t const k,
int64_t const cols, int64_t const k, int64_t const* num_valid_ptr, int64_t const* num_valid_ptr, cudaStream_t stream);
cudaStream_t stream);
void preprocessTopkIdLauncher(int* topk_id_ptr, int size, void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
const int* expert_map_ptr, int num_experts, const int* expert_map_ptr, int num_experts,

View File

@ -2,10 +2,9 @@
template <typename T, bool CHECK_SKIPPED, bool ALIGN_BLOCK_SIZE> template <typename T, bool CHECK_SKIPPED, bool ALIGN_BLOCK_SIZE>
__global__ void expandInputRowsKernel( __global__ void expandInputRowsKernel(
T const* unpermuted_input, T* permuted_output, T const* unpermuted_input, T* permuted_output, int* sorted_experts,
const float* unpermuted_scales, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row, int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row, int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int64_t* expert_first_token_offset, int64_t const num_rows, int64_t* expert_first_token_offset, int64_t const num_rows,
int64_t const* num_dest_rows, int64_t const cols, int64_t k, int64_t const* num_dest_rows, int64_t const cols, int64_t k,
int num_local_experts, int align_block_size) { int num_local_experts, int align_block_size) {
@ -54,6 +53,10 @@ __global__ void expandInputRowsKernel(
assert(expanded_dest_row <= INT32_MAX); assert(expanded_dest_row <= INT32_MAX);
expanded_source_row_to_expanded_dest_row[expanded_source_row] = expanded_source_row_to_expanded_dest_row[expanded_source_row] =
static_cast<int>(expanded_dest_row); static_cast<int>(expanded_dest_row);
// skip non local expert token
if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) {
permuted_idx[expanded_dest_row] = expanded_source_row;
}
} }
if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) { if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) {
@ -62,7 +65,7 @@ __global__ void expandInputRowsKernel(
using DataElem = cutlass::Array<T, ELEM_PER_THREAD>; using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;
// Duplicate and permute rows // Duplicate and permute rows
int64_t const source_row = expanded_source_row % num_rows; int64_t const source_row = expanded_source_row / k;
auto const* source_row_ptr = auto const* source_row_ptr =
reinterpret_cast<DataElem const*>(unpermuted_input + source_row * cols); reinterpret_cast<DataElem const*>(unpermuted_input + source_row * cols);
@ -82,10 +85,9 @@ __global__ void expandInputRowsKernel(
template <typename T> template <typename T>
void expandInputRowsKernelLauncher( void expandInputRowsKernelLauncher(
T const* unpermuted_input, T* permuted_output, T const* unpermuted_input, T* permuted_output, int* sorted_experts,
const float* unpermuted_scales, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row, int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row, int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int64_t* expert_first_token_offset, int64_t const num_rows, int64_t* expert_first_token_offset, int64_t const num_rows,
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k, int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
int num_local_experts, const int& align_block_size, cudaStream_t stream) { int num_local_experts, const int& align_block_size, cudaStream_t stream) {
@ -105,11 +107,11 @@ void expandInputRowsKernelLauncher(
int64_t smem_size = sizeof(int64_t) * (num_local_experts + 1); int64_t smem_size = sizeof(int64_t) * (num_local_experts + 1);
func<<<blocks, threads, smem_size, stream>>>( func<<<blocks, threads, smem_size, stream>>>(
unpermuted_input, permuted_output, unpermuted_scales, sorted_experts, unpermuted_input, permuted_output, sorted_experts,
expanded_dest_row_to_expanded_source_row, expanded_dest_row_to_expanded_source_row,
expanded_source_row_to_expanded_dest_row, expert_first_token_offset, expanded_source_row_to_expanded_dest_row, permuted_idx,
num_rows, num_valid_tokens_ptr, cols, k, num_local_experts, expert_first_token_offset, num_rows, num_valid_tokens_ptr, cols, k,
align_block_size); num_local_experts, align_block_size);
} }
template <class T, class U> template <class T, class U>
@ -128,11 +130,9 @@ template <typename T, typename OutputType, bool CHECK_SKIPPED>
__global__ void finalizeMoeRoutingKernel( __global__ void finalizeMoeRoutingKernel(
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
float const* scales, int const* expanded_source_row_to_expanded_dest_row, float const* scales, int const* expanded_source_row_to_expanded_dest_row,
int const* expert_for_source_row, int64_t const orig_cols, int64_t const k, int64_t const orig_cols, int64_t const k, int64_t const* num_valid_ptr) {
int64_t const* num_valid_ptr) {
assert(orig_cols % 4 == 0); assert(orig_cols % 4 == 0);
int64_t const original_row = blockIdx.x; int64_t const original_row = blockIdx.x;
int64_t const num_rows = gridDim.x;
auto const offset = original_row * orig_cols; auto const offset = original_row * orig_cols;
OutputType* reduced_row_ptr = reduced_unpermuted_output + offset; OutputType* reduced_row_ptr = reduced_unpermuted_output + offset;
int64_t const num_valid = *num_valid_ptr; int64_t const num_valid = *num_valid_ptr;
@ -159,14 +159,13 @@ __global__ void finalizeMoeRoutingKernel(
ComputeElem thread_output; ComputeElem thread_output;
thread_output.fill(0); thread_output.fill(0);
for (int k_idx = 0; k_idx < k; ++k_idx) { for (int k_idx = 0; k_idx < k; ++k_idx) {
int64_t const expanded_original_row = original_row + k_idx * num_rows; int64_t const expanded_original_row = original_row * k + k_idx;
int64_t const expanded_permuted_row = int64_t const expanded_permuted_row =
expanded_source_row_to_expanded_dest_row[expanded_original_row]; expanded_source_row_to_expanded_dest_row[expanded_original_row];
int64_t const k_offset = original_row * k + k_idx; int64_t const k_offset = original_row * k + k_idx;
float const row_scale = scales[k_offset]; float const row_scale = scales[k_offset];
// Check after row_rescale has accumulated
if (CHECK_SKIPPED && expanded_permuted_row >= num_valid) { if (CHECK_SKIPPED && expanded_permuted_row >= num_valid) {
continue; continue;
} }
@ -189,9 +188,8 @@ template <class T, class OutputType>
void finalizeMoeRoutingKernelLauncher( void finalizeMoeRoutingKernelLauncher(
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
float const* scales, int const* expanded_source_row_to_expanded_dest_row, float const* scales, int const* expanded_source_row_to_expanded_dest_row,
int const* expert_for_source_row, int64_t const num_rows, int64_t const num_rows, int64_t const cols, int64_t const k,
int64_t const cols, int64_t const k, int64_t const* num_valid_ptr, int64_t const* num_valid_ptr, cudaStream_t stream) {
cudaStream_t stream) {
int64_t const blocks = num_rows; int64_t const blocks = num_rows;
int64_t const threads = 256; int64_t const threads = 256;
bool const check_finished = num_valid_ptr != nullptr; bool const check_finished = num_valid_ptr != nullptr;
@ -201,6 +199,5 @@ void finalizeMoeRoutingKernelLauncher(
auto* const kernel = func_map[check_finished]; auto* const kernel = func_map[check_finished];
kernel<<<blocks, threads, 0, stream>>>( kernel<<<blocks, threads, 0, stream>>>(
expanded_permuted_rows, reduced_unpermuted_output, scales, expanded_permuted_rows, reduced_unpermuted_output, scales,
expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k, expanded_source_row_to_expanded_dest_row, cols, k, num_valid_ptr);
num_valid_ptr);
} }

View File

@ -56,18 +56,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" -> Tensor"); " -> Tensor");
m.def( m.def(
"moe_permute(Tensor input, Tensor topk_weight, Tensor! topk_ids," "moe_permute(Tensor input, Tensor topk_ids,"
"Tensor token_expert_indices, Tensor? expert_map, int n_expert," "Tensor token_expert_indices, Tensor? expert_map, int n_expert,"
"int n_local_expert," "int n_local_expert,"
"int topk, int? align_block_size,Tensor! permuted_input, Tensor! " "int topk, int? align_block_size,Tensor! permuted_input, Tensor! "
"expert_first_token_offset, Tensor! src_row_id2dst_row_id_map, Tensor! " "expert_first_token_offset, Tensor! inv_permuted_idx, Tensor! "
"m_indices)->()"); "permuted_idx, Tensor! m_indices)->()");
m.def( m.def(
"moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights," "moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights,"
"Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor " "Tensor inv_permuted_idx, Tensor? expert_first_token_offset, "
"expert_first_token_offset, int n_expert, int n_local_expert,int " "int topk, Tensor! hidden_states)->()");
"topk, Tensor! hidden_states)->()");
m.def("moe_permute_unpermute_supported() -> bool"); m.def("moe_permute_unpermute_supported() -> bool");
m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported); m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported);

View File

@ -17,28 +17,34 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
moe_permute, moe_permute_unpermute_supported, moe_unpermute) moe_permute, moe_permute_unpermute_supported, moe_unpermute)
from vllm.platforms import current_platform from vllm.platforms import current_platform
NUM_EXPERTS = [16, 64] NUM_EXPERTS = [16, 64, 256]
TOP_KS = [2, 4, 6, 8] TOP_KS = [2, 4, 6, 8]
EP_SIZE = [1, 4, 16] EP_SIZE = [1, 4, 16]
current_platform.seed_everything(0) current_platform.seed_everything(0)
def torch_permute(hidden_states: torch.Tensor, def torch_permute(
topk_ids: torch.Tensor, hidden_states: torch.Tensor,
token_expert_indices: torch.Tensor, topk_ids: torch.Tensor,
topk: int, # token_expert_indices: torch.Tensor,
n_expert: int, topk: int,
n_local_expert: int, n_expert: int,
start_expert: int, n_local_expert: int,
expert_map: Optional[torch.Tensor] = None, start_expert: int,
align_block_size: Optional[int] = None, expert_map: Optional[torch.Tensor] = None,
fill_invalid_expert: int = -1) -> list[torch.Tensor]: align_block_size: Optional[int] = None,
fill_invalid_expert: int = -1) -> list[torch.Tensor]:
n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1] n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1]
if expert_map is not None: if expert_map is not None:
is_local_expert = (expert_map[topk_ids] != -1) is_local_expert = (expert_map[topk_ids] != -1)
not_local_expert = (expert_map[topk_ids] == -1) not_local_expert = (expert_map[topk_ids] == -1)
topk_ids = is_local_expert * ( topk_ids = is_local_expert * (
topk_ids - start_expert) + not_local_expert * (topk_ids + n_expert) topk_ids - start_expert) + not_local_expert * (topk_ids + n_expert)
token_expert_indices = torch.arange(0,
n_token * topk,
dtype=torch.int32,
device=hidden_states.device).reshape(
(n_token, topk))
sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(), sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(),
stable=True) stable=True)
@ -59,8 +65,8 @@ def torch_permute(hidden_states: torch.Tensor,
valid_row_idx = [] valid_row_idx = []
if align_block_size is None: if align_block_size is None:
permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map % permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map //
n_token, ...] topk, ...]
permuted_row_size = permuted_hidden_states.shape[0] permuted_row_size = permuted_hidden_states.shape[0]
m_indices = torch.empty(permuted_row_size, m_indices = torch.empty(permuted_row_size,
device="cuda", device="cuda",
@ -73,14 +79,21 @@ def torch_permute(hidden_states: torch.Tensor,
0, n_token * topk, device="cuda", 0, n_token * topk, device="cuda",
dtype=torch.int32)[src2dst_idx].reshape((n_token, topk)) dtype=torch.int32)[src2dst_idx].reshape((n_token, topk))
valid_row_idx += [i for i in range(expert_first_token_offset[-1])] valid_row_idx += [i for i in range(expert_first_token_offset[-1])]
dst_row_id2src_row_id_map[
expert_first_token_offset[-1]:] = n_token * topk
return [ return [
permuted_hidden_states, expert_first_token_offset, permuted_hidden_states, expert_first_token_offset,
src_row_id2dst_row_id_map, m_indices, valid_row_idx src_row_id2dst_row_id_map, dst_row_id2src_row_id_map, m_indices,
valid_row_idx
] ]
else: else:
permuted_row_size = (topk * n_token + n_expert * permuted_row_size = (topk * n_token + n_expert *
(align_block_size - 1) + align_block_size - (align_block_size - 1) + align_block_size -
1) // align_block_size * align_block_size 1) // align_block_size * align_block_size
permuted_idx = torch.full((permuted_row_size, ),
n_token * topk,
dtype=torch.int32,
device=hidden_states.device)
permuted_hidden_states = torch.empty((permuted_row_size, n_hidden), permuted_hidden_states = torch.empty((permuted_row_size, n_hidden),
device="cuda", device="cuda",
dtype=hidden_states.dtype) dtype=hidden_states.dtype)
@ -105,13 +118,16 @@ def torch_permute(hidden_states: torch.Tensor,
align_first_token_offset = align_expert_first_token_offset[i - 1] align_first_token_offset = align_expert_first_token_offset[i - 1]
align_last_token_offset = align_expert_first_token_offset[i] align_last_token_offset = align_expert_first_token_offset[i]
dst_row_id2src_row_id_in_expert = dst_row_id2src_row_id_map[ dst_row_id2src_row_id_in_expert = dst_row_id2src_row_id_map[
first_token_offset:first_token_offset + first_token_offset:first_token_offset + n_token_in_expert]
n_token_in_expert] % n_token
# store token in current expert with align_first_token_offset # store token in current expert with align_first_token_offset
permuted_hidden_states[align_first_token_offset:\ permuted_hidden_states[align_first_token_offset:\
align_first_token_offset+n_token_in_expert,\ align_first_token_offset+n_token_in_expert,\
...] = hidden_states[\ ...] = hidden_states[\
dst_row_id2src_row_id_in_expert, ...] dst_row_id2src_row_id_in_expert // topk,\
...]
permuted_idx[align_first_token_offset:\
align_first_token_offset+\
n_token_in_expert] = dst_row_id2src_row_id_in_expert
# set current expert m_indices # set current expert m_indices
m_indices[align_first_token_offset:align_last_token_offset] = i - 1 m_indices[align_first_token_offset:align_last_token_offset] = i - 1
valid_row_idx += [ valid_row_idx += [
@ -135,7 +151,7 @@ def torch_permute(hidden_states: torch.Tensor,
src2dst_idx].reshape((n_token, topk)) src2dst_idx].reshape((n_token, topk))
return [ return [
permuted_hidden_states, align_expert_first_token_offset, permuted_hidden_states, align_expert_first_token_offset,
align_src_row_id2dst_row_id, m_indices, valid_row_idx align_src_row_id2dst_row_id, permuted_idx, m_indices, valid_row_idx
] ]
@ -146,15 +162,18 @@ def torch_unpermute(permuted_hidden_states: torch.Tensor,
valid_row_idx: torch.Tensor, topk: int, valid_row_idx: torch.Tensor, topk: int,
n_expert: int) -> torch.Tensor: n_expert: int) -> torch.Tensor:
# ignore invalid row # ignore invalid row
n_hidden = permuted_hidden_states.shape[1]
mask = torch.zeros(permuted_hidden_states.shape[0], mask = torch.zeros(permuted_hidden_states.shape[0],
dtype=bool, dtype=bool,
device="cuda") device="cuda")
mask[valid_row_idx] = True mask[valid_row_idx] = True
permuted_hidden_states[~mask] = 0 permuted_hidden_states[~mask] = 0
idx = src_row_id2dst_row_id_map.flatten()[
token_expert_indices.flatten()].reshape(token_expert_indices.shape) permuted_hidden_states = permuted_hidden_states[
output = permuted_hidden_states[idx, ...] * topk_weights[..., None] src_row_id2dst_row_id_map.flatten(), ...]
output = output.sum(dim=1).to(permuted_hidden_states.dtype) permuted_hidden_states = permuted_hidden_states.view(-1, topk, n_hidden)
output = (permuted_hidden_states * topk_weights.unsqueeze(2)).sum(1).to(
permuted_hidden_states.dtype)
return output return output
@ -184,43 +203,56 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype) gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype)
topk_weights, topk_ids, token_expert_indices = fused_topk( topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states, gating_output, topk, False) hidden_states, gating_output, topk, False)
gold0, gold1, gold2, gold3, valid_row_idx = torch_permute( (gold_permuted_hidden_states, gold_expert_first_token_offset,
hidden_states, gold_inv_permuted_idx, gold_permuted_idx, gold_m_indices,
topk_ids, valid_row_idx) = torch_permute(
token_expert_indices, hidden_states,
topk, topk_ids,
n_expert, # token_expert_indices,
n_local_expert, topk,
start_expert, n_expert,
expert_map=expert_map, n_local_expert,
align_block_size=align_block_size, start_expert,
fill_invalid_expert=fill_invalid_expert) expert_map=expert_map,
align_block_size=align_block_size,
fill_invalid_expert=fill_invalid_expert)
result0, result1, result2, result3 = moe_permute( (permuted_hidden_states, _, expert_first_token_offset, inv_permuted_idx,
hidden_states, topk_weights, topk_ids, token_expert_indices, topk, m_indices) = moe_permute(hidden_states=hidden_states,
n_expert, n_local_expert, expert_map, align_block_size, a1q_scale=None,
fill_invalid_expert) topk_ids=topk_ids,
n_expert=n_expert,
n_local_expert=n_local_expert,
expert_map=expert_map,
align_block_size=align_block_size,
fill_invalid_expert=fill_invalid_expert)
# check expert_first_token_offset # check expert_first_token_offset
torch.testing.assert_close(gold1, result1, atol=0, rtol=0) torch.testing.assert_close(gold_expert_first_token_offset,
# check src_row_id2dst_row_id_map expert_first_token_offset,
torch.testing.assert_close(gold2, result2, atol=0, rtol=0) atol=0,
# check mindice rtol=0)
torch.testing.assert_close(gold3, result3, atol=0, rtol=0) # check src_row_id2dst_row_id_map
# check permuted_hidden_states, only valid token torch.testing.assert_close(gold_inv_permuted_idx.flatten(),
torch.testing.assert_close(gold0[valid_row_idx], inv_permuted_idx,
result0[valid_row_idx], atol=0,
rtol=0)
# check mindice
torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0)
# check permuted_hidden_states, only valid token
torch.testing.assert_close(gold_permuted_hidden_states[valid_row_idx],
permuted_hidden_states[valid_row_idx],
atol=0, atol=0,
rtol=0) rtol=0)
# add a random tensor to simulate group gemm # add a random tensor to simulate group gemm
result0 = 0.5 * result0 + torch.randn_like(result0) result0 = 0.5 * permuted_hidden_states + torch.randn_like(
permuted_hidden_states)
result4 = torch.empty_like(hidden_states)
moe_unpermute(result4, result0, topk_weights, inv_permuted_idx,
expert_first_token_offset)
result4 = moe_unpermute(result0, topk_weights, topk_ids, result2, result1,
topk, n_expert, n_local_expert)
gold4 = torch_unpermute(result0, topk_weights, topk_ids, gold4 = torch_unpermute(result0, topk_weights, topk_ids,
token_expert_indices, result2, valid_row_idx, topk, token_expert_indices, inv_permuted_idx,
n_local_expert) valid_row_idx, topk, n_local_expert)
# check unpermuted hidden # check unpermuted hidden
torch.testing.assert_close(result4, gold4, atol=2e-2, rtol=0) torch.testing.assert_close(result4, gold4, atol=2e-2, rtol=0)

View File

@ -76,43 +76,43 @@ def _moe_unpermute_and_reduce(
def moe_permute( def moe_permute(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
topk_weights: torch.Tensor, a1q_scale: Optional[torch.Tensor],
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
token_expert_indices: torch.Tensor,
topk: int,
n_expert: int, n_expert: int,
n_local_expert: int, n_local_expert: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
align_block_size: Optional[int] = None, align_block_size: Optional[int] = None,
fill_invalid_expert: int = -1 fill_invalid_expert: int = -1
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
torch.Tensor]:
""" """
This function expands and permutes activation to gather uncontinuous tokens This function expands and permutes activation to gather uncontinuous tokens
for each expert. for each expert.
Parameters: Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer. - hidden_states (torch.Tensor): The input tensor to the MoE layer.
- topk_weights (torch.Tensor): topk expert route weight for each token. - a1q_scale (Optional[torch.Tensor]): quant scale for hidden_states
- topk_ids (torch.Tensor): topk expert route id for each token. - topk_ids (torch.Tensor): topk expert route id for each token.
- token_expert_indices (torch.Tensor): indice for expanded hidden.
- topk (int): The number of top-k experts to select.
- n_expert (int): The number of expert. - n_expert (int): The number of expert.
- n_local_expert (int): The number of expert in current EP rank. - n_local_expert (int): The number of expert in current EP rank.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert from the global expert space to the local expert space of the expert
parallel shard. parallel shard.
- align_block_size (Optional[int]): align group gemm block size for deepgemm - align_block_size (Optional[int]): align group gemm block size for deepgemm
- fill_invalid_expert(int): fill expert id in m_indices for invalid expert - fill_invalid_expert(int): fill expert id in m_indices for invalid expert
to workaround DeepGemm unsupported -1 in m_indices to workaround DeepGemm unsupported -1 in m_indices
Returns: Returns:
- permuted_hidden_states (torch.Tensor): permuted activation. - permuted_hidden_states (torch.Tensor): permuted activation.
- a1q_scale (Optional[torch.Tensor]): quant scale for hidden_states
- expert_first_token_offset (torch.Tensor): offset of the first token - expert_first_token_offset (torch.Tensor): offset of the first token
of each expert for standard grouped gemm. if enable 'align_block_size' of each expert for standard grouped gemm. if enable 'align_block_size'
expert_first_token_offset will align up to 'align_block_size'. expert_first_token_offset will align up to 'align_block_size'.
- src_row_id2dst_row_id_map (torch.Tensor): idx map for moe_unpermute. - inv_permuted_idx (torch.Tensor): idx map for moe_unpermute.
- permuted_idx (torch.Tensor): idx map from hidden to permuted_hidden.
- m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records - m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records
the group which the j-th row of the LHS belong to.` the group which the j-th row of the LHS belong to.`
""" """
n_token, n_hidden = hidden_states.size() n_token, n_hidden = hidden_states.size()
topk = topk_ids.size(1)
assert (n_hidden * hidden_states.element_size() assert (n_hidden * hidden_states.element_size()
) % 16 == 0, "permue kernel need hidden dim align to 16B" ) % 16 == 0, "permue kernel need hidden dim align to 16B"
permuted_row_size = n_token * topk permuted_row_size = n_token * topk
@ -120,12 +120,19 @@ def moe_permute(
permuted_row_size = (permuted_row_size + n_expert * permuted_row_size = (permuted_row_size + n_expert *
(align_block_size - 1) + align_block_size - (align_block_size - 1) + align_block_size -
1) // align_block_size * align_block_size 1) // align_block_size * align_block_size
if n_local_expert == -1:
n_local_expert = n_expert
permuted_hidden_states = torch.empty( permuted_hidden_states = torch.empty(
(permuted_row_size, n_hidden), (permuted_row_size, n_hidden),
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
device=hidden_states.device, device=hidden_states.device,
) )
token_expert_indices = torch.arange(0,
n_token * topk,
dtype=torch.int32,
device=hidden_states.device).reshape(
(n_token, topk))
m_indices = torch.full((permuted_row_size, ), m_indices = torch.full((permuted_row_size, ),
fill_invalid_expert, fill_invalid_expert,
dtype=torch.int32, dtype=torch.int32,
@ -133,57 +140,54 @@ def moe_permute(
expert_first_token_offset = torch.empty(n_local_expert + 1, expert_first_token_offset = torch.empty(n_local_expert + 1,
dtype=torch.int64, dtype=torch.int64,
device=hidden_states.device) device=hidden_states.device)
src_row_id2dst_row_id_map = torch.empty((n_token, topk), permuted_idx = torch.full((permuted_row_size, ),
dtype=torch.int32, n_token * topk,
device=hidden_states.device) dtype=torch.int32,
torch.ops._moe_C.moe_permute(hidden_states, topk_weights, topk_ids, device=hidden_states.device)
token_expert_indices, expert_map, n_expert, inv_permuted_idx = torch.empty((n_token, topk),
n_local_expert, topk, align_block_size, dtype=torch.int32,
permuted_hidden_states, device=hidden_states.device)
expert_first_token_offset, topk_ids = topk_ids.to(torch.int32)
src_row_id2dst_row_id_map, m_indices) torch.ops._moe_C.moe_permute(hidden_states, topk_ids, token_expert_indices,
return (permuted_hidden_states, expert_first_token_offset, expert_map, n_expert, n_local_expert, topk,
src_row_id2dst_row_id_map, m_indices) align_block_size, permuted_hidden_states,
expert_first_token_offset, inv_permuted_idx,
permuted_idx, m_indices)
if a1q_scale is not None:
a1q_scale = a1q_scale[permuted_idx.clamp(max=n_token * topk - 1) //
topk]
return (permuted_hidden_states, a1q_scale, expert_first_token_offset,
inv_permuted_idx.flatten(), m_indices)
def moe_unpermute( def moe_unpermute(
out: torch.Tensor,
permuted_hidden_states: torch.Tensor, permuted_hidden_states: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, inv_permuted_idx: torch.Tensor,
src_row_id2dst_row_id_map: torch.Tensor, expert_first_token_offset: Optional[torch.Tensor] = None,
expert_first_token_offset: torch.Tensor, ) -> None:
topk: int,
n_expert: int,
n_local_expert: int,
) -> torch.Tensor:
""" """
This function expands and permutes activation to gathering uncontinuous This function expands and permutes activation to gathering uncontinuous
tokens for each expert. tokens for each expert.
Parameters: Parameters:
- out (torch.Tensor): output tensor
- permuted_hidden_states (torch.Tensor): permuted activation. - permuted_hidden_states (torch.Tensor): permuted activation.
- topk_weights (torch.Tensor): topk expert route weight for each token. - topk_weights (torch.Tensor): topk expert route weight for each token.
- topk_ids (torch.Tensor): topk expert route id for each token. - inv_permuted_idx (torch.Tensor): row idx map for moe_unpermute.
- expert_first_token_offset (torch.Tensor): offset of the first token - expert_first_token_offset (Optional[torch.Tensor]): offset of the first
of each expert for grouped gemm. token of each expert for grouped gemm.
- topk (int): The number of top-k experts to select.
- n_expert (int): The number of expert.
- n_local_expert (int): The number of expert in current EP rank.
Returns: Returns:
- hidden_states (torch.Tensor): The reduced and unpermuted activation - hidden_states (torch.Tensor): The reduced and unpermuted activation
tensor. tensor.
""" """
n_token, n_hidden = topk_weights.size(0), permuted_hidden_states.size(-1) topk = topk_weights.size(1)
n_hidden = permuted_hidden_states.size(-1)
assert (n_hidden * permuted_hidden_states.element_size() assert (n_hidden * permuted_hidden_states.element_size()
) % 16 == 0, "unpermue kernel need hidden dim align to 16B" ) % 16 == 0, "unpermue kernel need hidden dim align to 16B"
hidden_states = torch.empty((n_token, n_hidden),
dtype=permuted_hidden_states.dtype,
device=permuted_hidden_states.device)
torch.ops._moe_C.moe_unpermute(permuted_hidden_states, topk_weights, torch.ops._moe_C.moe_unpermute(permuted_hidden_states, topk_weights,
topk_ids, src_row_id2dst_row_id_map, inv_permuted_idx, expert_first_token_offset,
expert_first_token_offset, n_expert, topk, out)
n_local_expert, topk, hidden_states)
return hidden_states
def moe_permute_unpermute_supported(): def moe_permute_unpermute_supported():