mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-24 03:50:20 +08:00
Fix CUDA permute/unpermute for use with DeepGemm Moe (#17934)
Signed-off-by: Caleb_Du <Caleb_Du@zju.edu.cn>
This commit is contained in:
parent
bda9d0535f
commit
57c22e57f9
@ -8,12 +8,13 @@ import ray
|
||||
import torch
|
||||
from transformers import AutoConfig
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import *
|
||||
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||
_moe_permute,
|
||||
_moe_unpermute_and_reduce,
|
||||
moe_permute,
|
||||
moe_unpermute,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import *
|
||||
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import *
|
||||
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
@ -63,18 +64,19 @@ def benchmark_permute(
|
||||
|
||||
def run():
|
||||
if use_customized_permute:
|
||||
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = (
|
||||
moe_permute(
|
||||
qhidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
token_expert_indices=token_expert_indices,
|
||||
topk=topk,
|
||||
n_expert=num_experts,
|
||||
n_local_expert=num_experts,
|
||||
expert_map=None,
|
||||
align_block_size=align_block_size,
|
||||
)
|
||||
(
|
||||
permuted_hidden_states,
|
||||
a1q_scale,
|
||||
first_token_off,
|
||||
inv_perm_idx,
|
||||
m_indices,
|
||||
) = moe_permute(
|
||||
qhidden_states,
|
||||
a1q_scale=None,
|
||||
topk_ids=topk_ids,
|
||||
n_expert=num_experts,
|
||||
expert_map=None,
|
||||
align_block_size=align_block_size,
|
||||
)
|
||||
else:
|
||||
(
|
||||
@ -150,18 +152,19 @@ def benchmark_unpermute(
|
||||
|
||||
def prepare():
|
||||
if use_customized_permute:
|
||||
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = (
|
||||
moe_permute(
|
||||
qhidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
token_expert_indices=token_expert_indices,
|
||||
topk=topk,
|
||||
n_expert=num_experts,
|
||||
n_local_expert=num_experts,
|
||||
expert_map=None,
|
||||
align_block_size=align_block_size,
|
||||
)
|
||||
(
|
||||
permuted_hidden_states,
|
||||
a1q_scale,
|
||||
first_token_off,
|
||||
inv_perm_idx,
|
||||
m_indices,
|
||||
) = moe_permute(
|
||||
qhidden_states,
|
||||
a1q_scale=None,
|
||||
topk_ids=topk_ids,
|
||||
n_expert=num_experts,
|
||||
expert_map=None,
|
||||
align_block_size=align_block_size,
|
||||
)
|
||||
# convert to fp16/bf16 as gemm output
|
||||
return (
|
||||
@ -191,16 +194,19 @@ def benchmark_unpermute(
|
||||
|
||||
def run(input: tuple):
|
||||
if use_customized_permute:
|
||||
(permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = input
|
||||
(
|
||||
permuted_hidden_states,
|
||||
first_token_off,
|
||||
inv_perm_idx,
|
||||
m_indices,
|
||||
) = input
|
||||
output = torch.empty_like(hidden_states)
|
||||
moe_unpermute(
|
||||
output,
|
||||
permuted_hidden_states,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inv_perm_idx,
|
||||
first_token_off,
|
||||
topk,
|
||||
num_experts,
|
||||
num_experts,
|
||||
)
|
||||
else:
|
||||
(
|
||||
@ -211,7 +217,11 @@ def benchmark_unpermute(
|
||||
inv_perm,
|
||||
) = input
|
||||
_moe_unpermute_and_reduce(
|
||||
output_hidden_states, permuted_hidden_states, inv_perm, topk_weights
|
||||
output_hidden_states,
|
||||
permuted_hidden_states,
|
||||
inv_perm,
|
||||
topk_weights,
|
||||
True,
|
||||
)
|
||||
|
||||
# JIT compilation & warmup
|
||||
|
||||
@ -10,32 +10,28 @@
|
||||
|
||||
void moe_permute(
|
||||
const torch::Tensor& input, // [n_token, hidden]
|
||||
const torch::Tensor& topk_weights, //[n_token, topk]
|
||||
torch::Tensor& topk_ids, // [n_token, topk]
|
||||
const torch::Tensor& topk_ids, // [n_token, topk]
|
||||
const torch::Tensor& token_expert_indices, // [n_token, topk]
|
||||
const std::optional<torch::Tensor>& expert_map, // [n_expert]
|
||||
int64_t n_expert, int64_t n_local_expert, int64_t topk,
|
||||
const std::optional<int64_t>& align_block_size,
|
||||
torch::Tensor&
|
||||
permuted_input, // [topk * n_token/align_block_size_m, hidden]
|
||||
torch::Tensor& permuted_input, // [permuted_size, hidden]
|
||||
torch::Tensor& expert_first_token_offset, // [n_local_expert + 1]
|
||||
torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk]
|
||||
torch::Tensor& inv_permuted_idx, // [n_token, topk]
|
||||
torch::Tensor& permuted_idx, // [permute_size]
|
||||
torch::Tensor& m_indices) { // [align_expand_m]
|
||||
TORCH_CHECK(topk_weights.scalar_type() == at::ScalarType::Float,
|
||||
"topk_weights must be float32");
|
||||
TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long,
|
||||
"expert_first_token_offset must be int64");
|
||||
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
|
||||
"topk_ids must be int32");
|
||||
TORCH_CHECK(token_expert_indices.scalar_type() == at::ScalarType::Int,
|
||||
"token_expert_indices must be int32");
|
||||
TORCH_CHECK(src_row_id2dst_row_id_map.scalar_type() == at::ScalarType::Int,
|
||||
"src_row_id2dst_row_id_map must be int32");
|
||||
TORCH_CHECK(inv_permuted_idx.scalar_type() == at::ScalarType::Int,
|
||||
"inv_permuted_idx must be int32");
|
||||
TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1,
|
||||
"expert_first_token_offset shape != n_local_expert+1")
|
||||
TORCH_CHECK(
|
||||
src_row_id2dst_row_id_map.sizes() == token_expert_indices.sizes(),
|
||||
"token_expert_indices shape must be same as src_row_id2dst_row_id_map");
|
||||
TORCH_CHECK(inv_permuted_idx.sizes() == token_expert_indices.sizes(),
|
||||
"token_expert_indices shape must be same as inv_permuted_idx");
|
||||
auto n_token = input.sizes()[0];
|
||||
auto n_hidden = input.sizes()[1];
|
||||
auto align_block_size_value =
|
||||
@ -46,8 +42,9 @@ void moe_permute(
|
||||
auto sort_workspace = torch::empty(
|
||||
{sorter_size},
|
||||
torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
|
||||
auto copy_topk_ids = topk_ids.clone(); // copy topk_ids for preprocess
|
||||
auto permuted_experts_id = torch::empty_like(topk_ids);
|
||||
auto dst_row_id2src_row_id_map = torch::empty_like(src_row_id2dst_row_id_map);
|
||||
auto sorted_row_idx = torch::empty_like(inv_permuted_idx);
|
||||
auto align_expert_first_token_offset =
|
||||
torch::zeros_like(expert_first_token_offset);
|
||||
|
||||
@ -67,24 +64,22 @@ void moe_permute(
|
||||
const int* expert_map_ptr = get_ptr<int>(expert_map.value());
|
||||
valid_num_ptr =
|
||||
get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
|
||||
preprocessTopkIdLauncher(get_ptr<int>(topk_ids), n_token * topk,
|
||||
preprocessTopkIdLauncher(get_ptr<int>(copy_topk_ids), n_token * topk,
|
||||
expert_map_ptr, n_expert, stream);
|
||||
}
|
||||
// expert sort topk expert id and scan expert id get expert_first_token_offset
|
||||
sortAndScanExpert(get_ptr<int>(topk_ids), get_ptr<int>(token_expert_indices),
|
||||
get_ptr<int>(permuted_experts_id),
|
||||
get_ptr<int>(dst_row_id2src_row_id_map),
|
||||
get_ptr<int64_t>(expert_first_token_offset), n_token,
|
||||
n_expert, n_local_expert, topk, sorter,
|
||||
get_ptr<int>(sort_workspace), stream);
|
||||
sortAndScanExpert(
|
||||
get_ptr<int>(copy_topk_ids), get_ptr<int>(token_expert_indices),
|
||||
get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
|
||||
get_ptr<int64_t>(expert_first_token_offset), n_token, n_expert,
|
||||
n_local_expert, topk, sorter, get_ptr<int>(sort_workspace), stream);
|
||||
|
||||
// dispatch expandInputRowsKernelLauncher
|
||||
MOE_DISPATCH(input.scalar_type(), [&] {
|
||||
expandInputRowsKernelLauncher<scalar_t>(
|
||||
get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input),
|
||||
get_ptr<float>(topk_weights), get_ptr<int>(permuted_experts_id),
|
||||
get_ptr<int>(dst_row_id2src_row_id_map),
|
||||
get_ptr<int>(src_row_id2dst_row_id_map),
|
||||
get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
|
||||
get_ptr<int>(inv_permuted_idx), get_ptr<int>(permuted_idx),
|
||||
get_ptr<int64_t>(expert_first_token_offset), n_token, valid_num_ptr,
|
||||
n_hidden, topk, n_local_expert, align_block_size_value, stream);
|
||||
});
|
||||
@ -101,32 +96,34 @@ void moe_permute(
|
||||
}
|
||||
|
||||
void moe_unpermute(
|
||||
const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden]
|
||||
const torch::Tensor& topk_weights, //[n_token, topk]
|
||||
const torch::Tensor& topk_ids, // [n_token, topk]
|
||||
const torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk]
|
||||
const torch::Tensor& expert_first_token_offset, // [n_local_expert+1]
|
||||
int64_t n_expert, int64_t n_local_expert, int64_t topk,
|
||||
const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden]
|
||||
const torch::Tensor& topk_weights, // [n_token, topk]
|
||||
const torch::Tensor& inv_permuted_idx, // [n_token, topk]
|
||||
const std::optional<torch::Tensor>&
|
||||
expert_first_token_offset, // [n_local_expert+1]
|
||||
int64_t topk,
|
||||
torch::Tensor& hidden_states // [n_token, hidden]
|
||||
) {
|
||||
TORCH_CHECK(src_row_id2dst_row_id_map.sizes() == topk_ids.sizes(),
|
||||
"topk_ids shape must be same as src_row_id2dst_row_id_map");
|
||||
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
|
||||
"topk_ids must be int32");
|
||||
TORCH_CHECK(
|
||||
permuted_hidden_states.scalar_type() == hidden_states.scalar_type(),
|
||||
"topk_ids dtype must be same as src_row_id2dst_row_id_map");
|
||||
"permuted_hidden_states dtype must be same as hidden_states");
|
||||
auto n_token = hidden_states.size(0);
|
||||
auto n_hidden = hidden_states.size(1);
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
const int64_t* valid_ptr =
|
||||
get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
|
||||
|
||||
int64_t const* valid_ptr = nullptr;
|
||||
if (expert_first_token_offset.has_value()) {
|
||||
int n_local_expert = expert_first_token_offset.value().size(0) - 1;
|
||||
valid_ptr =
|
||||
get_ptr<int64_t>(expert_first_token_offset.value()) + n_local_expert;
|
||||
}
|
||||
|
||||
MOE_DISPATCH(hidden_states.scalar_type(), [&] {
|
||||
finalizeMoeRoutingKernelLauncher<scalar_t, scalar_t>(
|
||||
get_ptr<scalar_t>(permuted_hidden_states),
|
||||
get_ptr<scalar_t>(hidden_states), get_ptr<float>(topk_weights),
|
||||
get_ptr<int>(src_row_id2dst_row_id_map), get_ptr<int>(topk_ids),
|
||||
n_token, n_hidden, topk, valid_ptr, stream);
|
||||
get_ptr<int>(inv_permuted_idx), n_token, n_hidden, topk, valid_ptr,
|
||||
stream);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@ -177,7 +177,7 @@ __global__ void getMIndicesKernel(int64_t* expert_first_token_offset,
|
||||
int tidx = threadIdx.x;
|
||||
extern __shared__ int64_t smem_expert_first_token_offset[];
|
||||
for (int i = tidx; i <= num_local_expert; i += blockDim.x) {
|
||||
smem_expert_first_token_offset[tidx] = __ldg(expert_first_token_offset + i);
|
||||
smem_expert_first_token_offset[i] = __ldg(expert_first_token_offset + i);
|
||||
}
|
||||
__syncthreads();
|
||||
auto last_token_offset = smem_expert_first_token_offset[eidx + 1];
|
||||
|
||||
@ -57,31 +57,19 @@ void sortAndScanExpert(int* expert_for_source_row, const int* source_rows,
|
||||
|
||||
template <typename T>
|
||||
void expandInputRowsKernelLauncher(
|
||||
T const* unpermuted_input, T* permuted_output,
|
||||
const float* unpermuted_scales, int* sorted_experts,
|
||||
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
|
||||
int const* expanded_dest_row_to_expanded_source_row,
|
||||
int* expanded_source_row_to_expanded_dest_row,
|
||||
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
|
||||
int64_t* expert_first_token_offset, int64_t const num_rows,
|
||||
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
|
||||
int num_local_experts, const int& align_block_size, cudaStream_t stream);
|
||||
|
||||
// Final kernel to unpermute and scale
|
||||
// This kernel unpermutes the original data, does the k-way reduction and
|
||||
// performs the final skip connection.
|
||||
template <typename T, typename OutputType, bool CHECK_SKIPPED>
|
||||
__global__ void finalizeMoeRoutingKernel(
|
||||
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
|
||||
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
|
||||
int const* expert_for_source_row, int64_t const orig_cols, int64_t const k,
|
||||
int64_t const* num_valid_ptr);
|
||||
|
||||
template <class T, class OutputType>
|
||||
void finalizeMoeRoutingKernelLauncher(
|
||||
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
|
||||
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
|
||||
int const* expert_for_source_row, int64_t const num_rows,
|
||||
int64_t const cols, int64_t const k, int64_t const* num_valid_ptr,
|
||||
cudaStream_t stream);
|
||||
int64_t const num_rows, int64_t const cols, int64_t const k,
|
||||
int64_t const* num_valid_ptr, cudaStream_t stream);
|
||||
|
||||
void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
|
||||
const int* expert_map_ptr, int num_experts,
|
||||
|
||||
@ -2,10 +2,9 @@
|
||||
|
||||
template <typename T, bool CHECK_SKIPPED, bool ALIGN_BLOCK_SIZE>
|
||||
__global__ void expandInputRowsKernel(
|
||||
T const* unpermuted_input, T* permuted_output,
|
||||
const float* unpermuted_scales, int* sorted_experts,
|
||||
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
|
||||
int const* expanded_dest_row_to_expanded_source_row,
|
||||
int* expanded_source_row_to_expanded_dest_row,
|
||||
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
|
||||
int64_t* expert_first_token_offset, int64_t const num_rows,
|
||||
int64_t const* num_dest_rows, int64_t const cols, int64_t k,
|
||||
int num_local_experts, int align_block_size) {
|
||||
@ -54,6 +53,10 @@ __global__ void expandInputRowsKernel(
|
||||
assert(expanded_dest_row <= INT32_MAX);
|
||||
expanded_source_row_to_expanded_dest_row[expanded_source_row] =
|
||||
static_cast<int>(expanded_dest_row);
|
||||
// skip non local expert token
|
||||
if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) {
|
||||
permuted_idx[expanded_dest_row] = expanded_source_row;
|
||||
}
|
||||
}
|
||||
|
||||
if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) {
|
||||
@ -62,7 +65,7 @@ __global__ void expandInputRowsKernel(
|
||||
using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;
|
||||
|
||||
// Duplicate and permute rows
|
||||
int64_t const source_row = expanded_source_row % num_rows;
|
||||
int64_t const source_row = expanded_source_row / k;
|
||||
|
||||
auto const* source_row_ptr =
|
||||
reinterpret_cast<DataElem const*>(unpermuted_input + source_row * cols);
|
||||
@ -82,10 +85,9 @@ __global__ void expandInputRowsKernel(
|
||||
|
||||
template <typename T>
|
||||
void expandInputRowsKernelLauncher(
|
||||
T const* unpermuted_input, T* permuted_output,
|
||||
const float* unpermuted_scales, int* sorted_experts,
|
||||
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
|
||||
int const* expanded_dest_row_to_expanded_source_row,
|
||||
int* expanded_source_row_to_expanded_dest_row,
|
||||
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
|
||||
int64_t* expert_first_token_offset, int64_t const num_rows,
|
||||
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
|
||||
int num_local_experts, const int& align_block_size, cudaStream_t stream) {
|
||||
@ -105,11 +107,11 @@ void expandInputRowsKernelLauncher(
|
||||
int64_t smem_size = sizeof(int64_t) * (num_local_experts + 1);
|
||||
|
||||
func<<<blocks, threads, smem_size, stream>>>(
|
||||
unpermuted_input, permuted_output, unpermuted_scales, sorted_experts,
|
||||
unpermuted_input, permuted_output, sorted_experts,
|
||||
expanded_dest_row_to_expanded_source_row,
|
||||
expanded_source_row_to_expanded_dest_row, expert_first_token_offset,
|
||||
num_rows, num_valid_tokens_ptr, cols, k, num_local_experts,
|
||||
align_block_size);
|
||||
expanded_source_row_to_expanded_dest_row, permuted_idx,
|
||||
expert_first_token_offset, num_rows, num_valid_tokens_ptr, cols, k,
|
||||
num_local_experts, align_block_size);
|
||||
}
|
||||
|
||||
template <class T, class U>
|
||||
@ -128,11 +130,9 @@ template <typename T, typename OutputType, bool CHECK_SKIPPED>
|
||||
__global__ void finalizeMoeRoutingKernel(
|
||||
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
|
||||
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
|
||||
int const* expert_for_source_row, int64_t const orig_cols, int64_t const k,
|
||||
int64_t const* num_valid_ptr) {
|
||||
int64_t const orig_cols, int64_t const k, int64_t const* num_valid_ptr) {
|
||||
assert(orig_cols % 4 == 0);
|
||||
int64_t const original_row = blockIdx.x;
|
||||
int64_t const num_rows = gridDim.x;
|
||||
auto const offset = original_row * orig_cols;
|
||||
OutputType* reduced_row_ptr = reduced_unpermuted_output + offset;
|
||||
int64_t const num_valid = *num_valid_ptr;
|
||||
@ -159,14 +159,13 @@ __global__ void finalizeMoeRoutingKernel(
|
||||
ComputeElem thread_output;
|
||||
thread_output.fill(0);
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||
int64_t const expanded_original_row = original_row + k_idx * num_rows;
|
||||
int64_t const expanded_original_row = original_row * k + k_idx;
|
||||
int64_t const expanded_permuted_row =
|
||||
expanded_source_row_to_expanded_dest_row[expanded_original_row];
|
||||
|
||||
int64_t const k_offset = original_row * k + k_idx;
|
||||
float const row_scale = scales[k_offset];
|
||||
|
||||
// Check after row_rescale has accumulated
|
||||
if (CHECK_SKIPPED && expanded_permuted_row >= num_valid) {
|
||||
continue;
|
||||
}
|
||||
@ -189,9 +188,8 @@ template <class T, class OutputType>
|
||||
void finalizeMoeRoutingKernelLauncher(
|
||||
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
|
||||
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
|
||||
int const* expert_for_source_row, int64_t const num_rows,
|
||||
int64_t const cols, int64_t const k, int64_t const* num_valid_ptr,
|
||||
cudaStream_t stream) {
|
||||
int64_t const num_rows, int64_t const cols, int64_t const k,
|
||||
int64_t const* num_valid_ptr, cudaStream_t stream) {
|
||||
int64_t const blocks = num_rows;
|
||||
int64_t const threads = 256;
|
||||
bool const check_finished = num_valid_ptr != nullptr;
|
||||
@ -201,6 +199,5 @@ void finalizeMoeRoutingKernelLauncher(
|
||||
auto* const kernel = func_map[check_finished];
|
||||
kernel<<<blocks, threads, 0, stream>>>(
|
||||
expanded_permuted_rows, reduced_unpermuted_output, scales,
|
||||
expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k,
|
||||
num_valid_ptr);
|
||||
expanded_source_row_to_expanded_dest_row, cols, k, num_valid_ptr);
|
||||
}
|
||||
|
||||
@ -56,18 +56,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
" -> Tensor");
|
||||
|
||||
m.def(
|
||||
"moe_permute(Tensor input, Tensor topk_weight, Tensor! topk_ids,"
|
||||
"moe_permute(Tensor input, Tensor topk_ids,"
|
||||
"Tensor token_expert_indices, Tensor? expert_map, int n_expert,"
|
||||
"int n_local_expert,"
|
||||
"int topk, int? align_block_size,Tensor! permuted_input, Tensor! "
|
||||
"expert_first_token_offset, Tensor! src_row_id2dst_row_id_map, Tensor! "
|
||||
"m_indices)->()");
|
||||
"expert_first_token_offset, Tensor! inv_permuted_idx, Tensor! "
|
||||
"permuted_idx, Tensor! m_indices)->()");
|
||||
|
||||
m.def(
|
||||
"moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights,"
|
||||
"Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor "
|
||||
"expert_first_token_offset, int n_expert, int n_local_expert,int "
|
||||
"topk, Tensor! hidden_states)->()");
|
||||
"Tensor inv_permuted_idx, Tensor? expert_first_token_offset, "
|
||||
"int topk, Tensor! hidden_states)->()");
|
||||
|
||||
m.def("moe_permute_unpermute_supported() -> bool");
|
||||
m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported);
|
||||
|
||||
@ -17,28 +17,34 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||
moe_permute, moe_permute_unpermute_supported, moe_unpermute)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
NUM_EXPERTS = [16, 64]
|
||||
NUM_EXPERTS = [16, 64, 256]
|
||||
TOP_KS = [2, 4, 6, 8]
|
||||
EP_SIZE = [1, 4, 16]
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
|
||||
def torch_permute(hidden_states: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
token_expert_indices: torch.Tensor,
|
||||
topk: int,
|
||||
n_expert: int,
|
||||
n_local_expert: int,
|
||||
start_expert: int,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
align_block_size: Optional[int] = None,
|
||||
fill_invalid_expert: int = -1) -> list[torch.Tensor]:
|
||||
def torch_permute(
|
||||
hidden_states: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
# token_expert_indices: torch.Tensor,
|
||||
topk: int,
|
||||
n_expert: int,
|
||||
n_local_expert: int,
|
||||
start_expert: int,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
align_block_size: Optional[int] = None,
|
||||
fill_invalid_expert: int = -1) -> list[torch.Tensor]:
|
||||
n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1]
|
||||
if expert_map is not None:
|
||||
is_local_expert = (expert_map[topk_ids] != -1)
|
||||
not_local_expert = (expert_map[topk_ids] == -1)
|
||||
topk_ids = is_local_expert * (
|
||||
topk_ids - start_expert) + not_local_expert * (topk_ids + n_expert)
|
||||
token_expert_indices = torch.arange(0,
|
||||
n_token * topk,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device).reshape(
|
||||
(n_token, topk))
|
||||
|
||||
sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(),
|
||||
stable=True)
|
||||
@ -59,8 +65,8 @@ def torch_permute(hidden_states: torch.Tensor,
|
||||
valid_row_idx = []
|
||||
if align_block_size is None:
|
||||
|
||||
permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map %
|
||||
n_token, ...]
|
||||
permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map //
|
||||
topk, ...]
|
||||
permuted_row_size = permuted_hidden_states.shape[0]
|
||||
m_indices = torch.empty(permuted_row_size,
|
||||
device="cuda",
|
||||
@ -73,14 +79,21 @@ def torch_permute(hidden_states: torch.Tensor,
|
||||
0, n_token * topk, device="cuda",
|
||||
dtype=torch.int32)[src2dst_idx].reshape((n_token, topk))
|
||||
valid_row_idx += [i for i in range(expert_first_token_offset[-1])]
|
||||
dst_row_id2src_row_id_map[
|
||||
expert_first_token_offset[-1]:] = n_token * topk
|
||||
return [
|
||||
permuted_hidden_states, expert_first_token_offset,
|
||||
src_row_id2dst_row_id_map, m_indices, valid_row_idx
|
||||
src_row_id2dst_row_id_map, dst_row_id2src_row_id_map, m_indices,
|
||||
valid_row_idx
|
||||
]
|
||||
else:
|
||||
permuted_row_size = (topk * n_token + n_expert *
|
||||
(align_block_size - 1) + align_block_size -
|
||||
1) // align_block_size * align_block_size
|
||||
permuted_idx = torch.full((permuted_row_size, ),
|
||||
n_token * topk,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
permuted_hidden_states = torch.empty((permuted_row_size, n_hidden),
|
||||
device="cuda",
|
||||
dtype=hidden_states.dtype)
|
||||
@ -105,13 +118,16 @@ def torch_permute(hidden_states: torch.Tensor,
|
||||
align_first_token_offset = align_expert_first_token_offset[i - 1]
|
||||
align_last_token_offset = align_expert_first_token_offset[i]
|
||||
dst_row_id2src_row_id_in_expert = dst_row_id2src_row_id_map[
|
||||
first_token_offset:first_token_offset +
|
||||
n_token_in_expert] % n_token
|
||||
first_token_offset:first_token_offset + n_token_in_expert]
|
||||
# store token in current expert with align_first_token_offset
|
||||
permuted_hidden_states[align_first_token_offset:\
|
||||
align_first_token_offset+n_token_in_expert,\
|
||||
...] = hidden_states[\
|
||||
dst_row_id2src_row_id_in_expert, ...]
|
||||
dst_row_id2src_row_id_in_expert // topk,\
|
||||
...]
|
||||
permuted_idx[align_first_token_offset:\
|
||||
align_first_token_offset+\
|
||||
n_token_in_expert] = dst_row_id2src_row_id_in_expert
|
||||
# set current expert m_indices
|
||||
m_indices[align_first_token_offset:align_last_token_offset] = i - 1
|
||||
valid_row_idx += [
|
||||
@ -135,7 +151,7 @@ def torch_permute(hidden_states: torch.Tensor,
|
||||
src2dst_idx].reshape((n_token, topk))
|
||||
return [
|
||||
permuted_hidden_states, align_expert_first_token_offset,
|
||||
align_src_row_id2dst_row_id, m_indices, valid_row_idx
|
||||
align_src_row_id2dst_row_id, permuted_idx, m_indices, valid_row_idx
|
||||
]
|
||||
|
||||
|
||||
@ -146,15 +162,18 @@ def torch_unpermute(permuted_hidden_states: torch.Tensor,
|
||||
valid_row_idx: torch.Tensor, topk: int,
|
||||
n_expert: int) -> torch.Tensor:
|
||||
# ignore invalid row
|
||||
n_hidden = permuted_hidden_states.shape[1]
|
||||
mask = torch.zeros(permuted_hidden_states.shape[0],
|
||||
dtype=bool,
|
||||
device="cuda")
|
||||
mask[valid_row_idx] = True
|
||||
permuted_hidden_states[~mask] = 0
|
||||
idx = src_row_id2dst_row_id_map.flatten()[
|
||||
token_expert_indices.flatten()].reshape(token_expert_indices.shape)
|
||||
output = permuted_hidden_states[idx, ...] * topk_weights[..., None]
|
||||
output = output.sum(dim=1).to(permuted_hidden_states.dtype)
|
||||
|
||||
permuted_hidden_states = permuted_hidden_states[
|
||||
src_row_id2dst_row_id_map.flatten(), ...]
|
||||
permuted_hidden_states = permuted_hidden_states.view(-1, topk, n_hidden)
|
||||
output = (permuted_hidden_states * topk_weights.unsqueeze(2)).sum(1).to(
|
||||
permuted_hidden_states.dtype)
|
||||
return output
|
||||
|
||||
|
||||
@ -184,43 +203,56 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
|
||||
gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype)
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
hidden_states, gating_output, topk, False)
|
||||
gold0, gold1, gold2, gold3, valid_row_idx = torch_permute(
|
||||
hidden_states,
|
||||
topk_ids,
|
||||
token_expert_indices,
|
||||
topk,
|
||||
n_expert,
|
||||
n_local_expert,
|
||||
start_expert,
|
||||
expert_map=expert_map,
|
||||
align_block_size=align_block_size,
|
||||
fill_invalid_expert=fill_invalid_expert)
|
||||
(gold_permuted_hidden_states, gold_expert_first_token_offset,
|
||||
gold_inv_permuted_idx, gold_permuted_idx, gold_m_indices,
|
||||
valid_row_idx) = torch_permute(
|
||||
hidden_states,
|
||||
topk_ids,
|
||||
# token_expert_indices,
|
||||
topk,
|
||||
n_expert,
|
||||
n_local_expert,
|
||||
start_expert,
|
||||
expert_map=expert_map,
|
||||
align_block_size=align_block_size,
|
||||
fill_invalid_expert=fill_invalid_expert)
|
||||
|
||||
result0, result1, result2, result3 = moe_permute(
|
||||
hidden_states, topk_weights, topk_ids, token_expert_indices, topk,
|
||||
n_expert, n_local_expert, expert_map, align_block_size,
|
||||
fill_invalid_expert)
|
||||
(permuted_hidden_states, _, expert_first_token_offset, inv_permuted_idx,
|
||||
m_indices) = moe_permute(hidden_states=hidden_states,
|
||||
a1q_scale=None,
|
||||
topk_ids=topk_ids,
|
||||
n_expert=n_expert,
|
||||
n_local_expert=n_local_expert,
|
||||
expert_map=expert_map,
|
||||
align_block_size=align_block_size,
|
||||
fill_invalid_expert=fill_invalid_expert)
|
||||
|
||||
# check expert_first_token_offset
|
||||
torch.testing.assert_close(gold1, result1, atol=0, rtol=0)
|
||||
# check src_row_id2dst_row_id_map
|
||||
torch.testing.assert_close(gold2, result2, atol=0, rtol=0)
|
||||
# check mindice
|
||||
torch.testing.assert_close(gold3, result3, atol=0, rtol=0)
|
||||
# check permuted_hidden_states, only valid token
|
||||
torch.testing.assert_close(gold0[valid_row_idx],
|
||||
result0[valid_row_idx],
|
||||
torch.testing.assert_close(gold_expert_first_token_offset,
|
||||
expert_first_token_offset,
|
||||
atol=0,
|
||||
rtol=0)
|
||||
# check src_row_id2dst_row_id_map
|
||||
torch.testing.assert_close(gold_inv_permuted_idx.flatten(),
|
||||
inv_permuted_idx,
|
||||
atol=0,
|
||||
rtol=0)
|
||||
# check mindice
|
||||
torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0)
|
||||
# check permuted_hidden_states, only valid token
|
||||
torch.testing.assert_close(gold_permuted_hidden_states[valid_row_idx],
|
||||
permuted_hidden_states[valid_row_idx],
|
||||
atol=0,
|
||||
rtol=0)
|
||||
|
||||
# add a random tensor to simulate group gemm
|
||||
result0 = 0.5 * result0 + torch.randn_like(result0)
|
||||
result0 = 0.5 * permuted_hidden_states + torch.randn_like(
|
||||
permuted_hidden_states)
|
||||
result4 = torch.empty_like(hidden_states)
|
||||
moe_unpermute(result4, result0, topk_weights, inv_permuted_idx,
|
||||
expert_first_token_offset)
|
||||
|
||||
result4 = moe_unpermute(result0, topk_weights, topk_ids, result2, result1,
|
||||
topk, n_expert, n_local_expert)
|
||||
gold4 = torch_unpermute(result0, topk_weights, topk_ids,
|
||||
token_expert_indices, result2, valid_row_idx, topk,
|
||||
n_local_expert)
|
||||
|
||||
token_expert_indices, inv_permuted_idx,
|
||||
valid_row_idx, topk, n_local_expert)
|
||||
# check unpermuted hidden
|
||||
torch.testing.assert_close(result4, gold4, atol=2e-2, rtol=0)
|
||||
|
||||
@ -76,43 +76,43 @@ def _moe_unpermute_and_reduce(
|
||||
|
||||
def moe_permute(
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
topk_ids: torch.Tensor,
|
||||
token_expert_indices: torch.Tensor,
|
||||
topk: int,
|
||||
n_expert: int,
|
||||
n_local_expert: int,
|
||||
n_local_expert: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
align_block_size: Optional[int] = None,
|
||||
fill_invalid_expert: int = -1
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
|
||||
torch.Tensor]:
|
||||
"""
|
||||
This function expands and permutes activation to gather uncontinuous tokens
|
||||
for each expert.
|
||||
Parameters:
|
||||
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
|
||||
- topk_weights (torch.Tensor): topk expert route weight for each token.
|
||||
- a1q_scale (Optional[torch.Tensor]): quant scale for hidden_states
|
||||
- topk_ids (torch.Tensor): topk expert route id for each token.
|
||||
- token_expert_indices (torch.Tensor): indice for expanded hidden.
|
||||
- topk (int): The number of top-k experts to select.
|
||||
- n_expert (int): The number of expert.
|
||||
- n_local_expert (int): The number of expert in current EP rank.
|
||||
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
|
||||
from the global expert space to the local expert space of the expert
|
||||
from the global expert space to the local expert space of the expert
|
||||
parallel shard.
|
||||
- align_block_size (Optional[int]): align group gemm block size for deepgemm
|
||||
- fill_invalid_expert(int): fill expert id in m_indices for invalid expert
|
||||
to workaround DeepGemm unsupported -1 in m_indices
|
||||
Returns:
|
||||
- permuted_hidden_states (torch.Tensor): permuted activation.
|
||||
- a1q_scale (Optional[torch.Tensor]): quant scale for hidden_states
|
||||
- expert_first_token_offset (torch.Tensor): offset of the first token
|
||||
of each expert for standard grouped gemm. if enable 'align_block_size'
|
||||
expert_first_token_offset will align up to 'align_block_size'.
|
||||
- src_row_id2dst_row_id_map (torch.Tensor): idx map for moe_unpermute.
|
||||
- inv_permuted_idx (torch.Tensor): idx map for moe_unpermute.
|
||||
- permuted_idx (torch.Tensor): idx map from hidden to permuted_hidden.
|
||||
- m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records
|
||||
the group which the j-th row of the LHS belong to.`
|
||||
"""
|
||||
n_token, n_hidden = hidden_states.size()
|
||||
topk = topk_ids.size(1)
|
||||
assert (n_hidden * hidden_states.element_size()
|
||||
) % 16 == 0, "permue kernel need hidden dim align to 16B"
|
||||
permuted_row_size = n_token * topk
|
||||
@ -120,12 +120,19 @@ def moe_permute(
|
||||
permuted_row_size = (permuted_row_size + n_expert *
|
||||
(align_block_size - 1) + align_block_size -
|
||||
1) // align_block_size * align_block_size
|
||||
|
||||
if n_local_expert == -1:
|
||||
n_local_expert = n_expert
|
||||
permuted_hidden_states = torch.empty(
|
||||
(permuted_row_size, n_hidden),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
token_expert_indices = torch.arange(0,
|
||||
n_token * topk,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device).reshape(
|
||||
(n_token, topk))
|
||||
|
||||
m_indices = torch.full((permuted_row_size, ),
|
||||
fill_invalid_expert,
|
||||
dtype=torch.int32,
|
||||
@ -133,57 +140,54 @@ def moe_permute(
|
||||
expert_first_token_offset = torch.empty(n_local_expert + 1,
|
||||
dtype=torch.int64,
|
||||
device=hidden_states.device)
|
||||
src_row_id2dst_row_id_map = torch.empty((n_token, topk),
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
torch.ops._moe_C.moe_permute(hidden_states, topk_weights, topk_ids,
|
||||
token_expert_indices, expert_map, n_expert,
|
||||
n_local_expert, topk, align_block_size,
|
||||
permuted_hidden_states,
|
||||
expert_first_token_offset,
|
||||
src_row_id2dst_row_id_map, m_indices)
|
||||
return (permuted_hidden_states, expert_first_token_offset,
|
||||
src_row_id2dst_row_id_map, m_indices)
|
||||
permuted_idx = torch.full((permuted_row_size, ),
|
||||
n_token * topk,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
inv_permuted_idx = torch.empty((n_token, topk),
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
torch.ops._moe_C.moe_permute(hidden_states, topk_ids, token_expert_indices,
|
||||
expert_map, n_expert, n_local_expert, topk,
|
||||
align_block_size, permuted_hidden_states,
|
||||
expert_first_token_offset, inv_permuted_idx,
|
||||
permuted_idx, m_indices)
|
||||
if a1q_scale is not None:
|
||||
a1q_scale = a1q_scale[permuted_idx.clamp(max=n_token * topk - 1) //
|
||||
topk]
|
||||
return (permuted_hidden_states, a1q_scale, expert_first_token_offset,
|
||||
inv_permuted_idx.flatten(), m_indices)
|
||||
|
||||
|
||||
def moe_unpermute(
|
||||
out: torch.Tensor,
|
||||
permuted_hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
src_row_id2dst_row_id_map: torch.Tensor,
|
||||
expert_first_token_offset: torch.Tensor,
|
||||
topk: int,
|
||||
n_expert: int,
|
||||
n_local_expert: int,
|
||||
) -> torch.Tensor:
|
||||
inv_permuted_idx: torch.Tensor,
|
||||
expert_first_token_offset: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
"""
|
||||
This function expands and permutes activation to gathering uncontinuous
|
||||
tokens for each expert.
|
||||
Parameters:
|
||||
- out (torch.Tensor): output tensor
|
||||
- permuted_hidden_states (torch.Tensor): permuted activation.
|
||||
- topk_weights (torch.Tensor): topk expert route weight for each token.
|
||||
- topk_ids (torch.Tensor): topk expert route id for each token.
|
||||
- expert_first_token_offset (torch.Tensor): offset of the first token
|
||||
of each expert for grouped gemm.
|
||||
- topk (int): The number of top-k experts to select.
|
||||
- n_expert (int): The number of expert.
|
||||
- n_local_expert (int): The number of expert in current EP rank.
|
||||
- inv_permuted_idx (torch.Tensor): row idx map for moe_unpermute.
|
||||
- expert_first_token_offset (Optional[torch.Tensor]): offset of the first
|
||||
token of each expert for grouped gemm.
|
||||
Returns:
|
||||
- hidden_states (torch.Tensor): The reduced and unpermuted activation
|
||||
tensor.
|
||||
"""
|
||||
n_token, n_hidden = topk_weights.size(0), permuted_hidden_states.size(-1)
|
||||
topk = topk_weights.size(1)
|
||||
n_hidden = permuted_hidden_states.size(-1)
|
||||
assert (n_hidden * permuted_hidden_states.element_size()
|
||||
) % 16 == 0, "unpermue kernel need hidden dim align to 16B"
|
||||
hidden_states = torch.empty((n_token, n_hidden),
|
||||
dtype=permuted_hidden_states.dtype,
|
||||
device=permuted_hidden_states.device)
|
||||
|
||||
torch.ops._moe_C.moe_unpermute(permuted_hidden_states, topk_weights,
|
||||
topk_ids, src_row_id2dst_row_id_map,
|
||||
expert_first_token_offset, n_expert,
|
||||
n_local_expert, topk, hidden_states)
|
||||
return hidden_states
|
||||
inv_permuted_idx, expert_first_token_offset,
|
||||
topk, out)
|
||||
|
||||
|
||||
def moe_permute_unpermute_supported():
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user