mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 15:57:15 +08:00
permute/unpermute kernel for moe optimization (#14568)
Signed-off-by: Caleb_Du <Caleb_Du@zju.edu.cn>
This commit is contained in:
parent
0f87d8f7b2
commit
3e887d2e0c
@ -15,7 +15,6 @@ project(vllm_extensions LANGUAGES CXX)
|
||||
|
||||
# CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py)
|
||||
set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM")
|
||||
|
||||
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
|
||||
message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")
|
||||
|
||||
@ -682,6 +681,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
set(MOE_PERMUTE_SRC
|
||||
"csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu"
|
||||
"csrc/moe/moe_permute_unpermute_op.cu")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_PERMUTE_SRC}"
|
||||
CUDA_ARCHS "${MOE_PERMUTE_ARCHS}")
|
||||
|
||||
list(APPEND VLLM_MOE_EXT_SRC "${MOE_PERMUTE_SRC}")
|
||||
endif()
|
||||
message(STATUS "Enabling moe extension.")
|
||||
define_gpu_extension_target(
|
||||
_moe_C
|
||||
@ -690,6 +700,8 @@ define_gpu_extension_target(
|
||||
SOURCES ${VLLM_MOE_EXT_SRC}
|
||||
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
||||
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
||||
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
|
||||
INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
|
||||
USE_SABI 3
|
||||
WITH_SOABI)
|
||||
|
||||
|
||||
@ -90,7 +90,8 @@ def bench_run(results: list[benchmark.Measurement], model: str,
|
||||
|
||||
score = torch.randn((m, num_experts), device="cuda", dtype=dtype)
|
||||
|
||||
topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
a, score, topk, renormalize=False)
|
||||
|
||||
def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||
|
||||
@ -115,8 +115,8 @@ def benchmark_config(config: BenchmarkConfig,
|
||||
from vllm.model_executor.layers.fused_moe import override_config
|
||||
with override_config(config):
|
||||
if use_deep_gemm:
|
||||
topk_weights, topk_ids = fused_topk(x, input_gating, topk,
|
||||
False)
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
x, input_gating, topk, False)
|
||||
return fused_experts(
|
||||
x,
|
||||
w1,
|
||||
|
||||
349
benchmarks/kernels/benchmark_moe_permute_unpermute.py
Normal file
349
benchmarks/kernels/benchmark_moe_permute_unpermute.py
Normal file
@ -0,0 +1,349 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import argparse
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import ray
|
||||
import torch
|
||||
from transformers import AutoConfig
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
_moe_permute, _moe_unpermute_and_reduce)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import *
|
||||
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import *
|
||||
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
class BenchmarkConfig(TypedDict):
|
||||
BLOCK_SIZE_M: int
|
||||
BLOCK_SIZE_N: int
|
||||
BLOCK_SIZE_K: int
|
||||
GROUP_SIZE_M: int
|
||||
num_warps: int
|
||||
num_stages: int
|
||||
|
||||
|
||||
def benchmark_permute(num_tokens: int,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
num_iters: int = 100,
|
||||
use_customized_permute: bool = False) -> float:
|
||||
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
||||
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
# output_hidden_states = torch.empty_like(hidden_states)
|
||||
if use_fp8_w8a8:
|
||||
align_block_size = 128 # deepgemm needs 128 m aligned block
|
||||
qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
|
||||
else:
|
||||
align_block_size = None
|
||||
qhidden_states = hidden_states
|
||||
|
||||
gating_output = torch.randn(num_iters,
|
||||
num_tokens,
|
||||
num_experts,
|
||||
dtype=torch.float32)
|
||||
|
||||
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
qhidden_states, input_gating, topk, False)
|
||||
|
||||
def prepare(i: int):
|
||||
input_gating.copy_(gating_output[i])
|
||||
|
||||
def run():
|
||||
if use_customized_permute:
|
||||
(permuted_hidden_states, first_token_off, inv_perm_idx,
|
||||
m_indices) = moe_permute(
|
||||
qhidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
token_expert_indices=token_expert_indices,
|
||||
topk=topk,
|
||||
n_expert=num_experts,
|
||||
n_local_expert=num_experts,
|
||||
expert_map=None,
|
||||
align_block_size=align_block_size,
|
||||
)
|
||||
else:
|
||||
(permuted_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
|
||||
inv_perm) = _moe_permute(qhidden_states, None, topk_ids,
|
||||
num_experts, None, align_block_size)
|
||||
|
||||
# JIT compilation & warmup
|
||||
run()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Capture 10 invocations with CUDA graph
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
for _ in range(10):
|
||||
run()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Warmup
|
||||
for _ in range(5):
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
latencies: list[float] = []
|
||||
for i in range(num_iters):
|
||||
prepare(i)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event.record()
|
||||
graph.replay()
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
latencies.append(start_event.elapsed_time(end_event))
|
||||
avg = sum(latencies) / (num_iters * 10) * 1000 # us
|
||||
graph.reset()
|
||||
return avg
|
||||
|
||||
|
||||
def benchmark_unpermute(num_tokens: int,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
num_iters: int = 100,
|
||||
use_customized_permute: bool = False) -> float:
|
||||
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
||||
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
output_hidden_states = torch.empty_like(hidden_states)
|
||||
if use_fp8_w8a8:
|
||||
align_block_size = 128 # deepgemm needs 128 m aligned block
|
||||
qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
|
||||
else:
|
||||
align_block_size = None
|
||||
qhidden_states = hidden_states
|
||||
|
||||
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
|
||||
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
qhidden_states, input_gating, topk, False)
|
||||
|
||||
def prepare():
|
||||
if use_customized_permute:
|
||||
(permuted_hidden_states, first_token_off, inv_perm_idx,
|
||||
m_indices) = moe_permute(
|
||||
qhidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
token_expert_indices=token_expert_indices,
|
||||
topk=topk,
|
||||
n_expert=num_experts,
|
||||
n_local_expert=num_experts,
|
||||
expert_map=None,
|
||||
align_block_size=align_block_size,
|
||||
)
|
||||
# convert to fp16/bf16 as gemm output
|
||||
return (permuted_hidden_states.to(dtype), first_token_off,
|
||||
inv_perm_idx, m_indices)
|
||||
else:
|
||||
(permuted_qhidden_states, a1q_scale, sorted_token_ids, expert_ids,
|
||||
inv_perm) = _moe_permute(qhidden_states, None, topk_ids,
|
||||
num_experts, None, align_block_size)
|
||||
# convert to fp16/bf16 as gemm output
|
||||
return (permuted_qhidden_states.to(dtype), a1q_scale,
|
||||
sorted_token_ids, expert_ids, inv_perm)
|
||||
|
||||
def run(input: tuple):
|
||||
if use_customized_permute:
|
||||
(permuted_hidden_states, first_token_off, inv_perm_idx,
|
||||
m_indices) = input
|
||||
moe_unpermute(permuted_hidden_states, topk_weights, topk_ids,
|
||||
inv_perm_idx, first_token_off, topk, num_experts,
|
||||
num_experts)
|
||||
else:
|
||||
(permuted_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
|
||||
inv_perm) = input
|
||||
_moe_unpermute_and_reduce(output_hidden_states,
|
||||
permuted_hidden_states, inv_perm,
|
||||
topk_weights)
|
||||
|
||||
# JIT compilation & warmup
|
||||
input = prepare()
|
||||
run(input)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Capture 10 invocations with CUDA graph
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
for _ in range(10):
|
||||
run(input)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Warmup
|
||||
for _ in range(5):
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
latencies: list[float] = []
|
||||
for i in range(num_iters):
|
||||
torch.cuda.synchronize()
|
||||
start_event.record()
|
||||
graph.replay()
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
latencies.append(start_event.elapsed_time(end_event))
|
||||
avg = sum(latencies) / (num_iters * 10) * 1000 # us
|
||||
graph.reset()
|
||||
return avg
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1)
|
||||
class BenchmarkWorker:
|
||||
|
||||
def __init__(self, seed: int) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
current_platform.seed_everything(seed)
|
||||
self.seed = seed
|
||||
# Get the device ID to allocate tensors and kernels
|
||||
# on the respective GPU. This is required for Ray to work
|
||||
# correctly with multi-GPU tuning on the ROCm platform.
|
||||
self.device_id = int(ray.get_gpu_ids()[0])
|
||||
|
||||
def benchmark(
|
||||
self,
|
||||
num_tokens: int,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_customized_permute: bool = False,
|
||||
) -> tuple[dict[str, int], float]:
|
||||
current_platform.seed_everything(self.seed)
|
||||
|
||||
permute_time = benchmark_permute(
|
||||
num_tokens,
|
||||
num_experts,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
num_iters=100,
|
||||
use_customized_permute=use_customized_permute)
|
||||
unpermute_time = benchmark_unpermute(
|
||||
num_tokens,
|
||||
num_experts,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
num_iters=100,
|
||||
use_customized_permute=use_customized_permute)
|
||||
return permute_time, unpermute_time
|
||||
|
||||
|
||||
def get_weight_block_size_safety(config, default_value=None):
|
||||
|
||||
quantization_config = getattr(config, 'quantization_config', {})
|
||||
if isinstance(quantization_config, dict):
|
||||
return quantization_config.get('weight_block_size', default_value)
|
||||
return default_value
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
print(args)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
args.model, trust_remote_code=args.trust_remote_code)
|
||||
if config.architectures[0] == "DbrxForCausalLM":
|
||||
E = config.ffn_config.moe_num_experts
|
||||
topk = config.ffn_config.moe_top_k
|
||||
elif config.architectures[0] == "JambaForCausalLM":
|
||||
E = config.num_experts
|
||||
topk = config.num_experts_per_tok
|
||||
elif (config.architectures[0] == "DeepseekV3ForCausalLM"
|
||||
or config.architectures[0] == "DeepseekV2ForCausalLM"):
|
||||
E = config.n_routed_experts
|
||||
topk = config.num_experts_per_tok
|
||||
elif config.architectures[0] in [
|
||||
"Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"
|
||||
]:
|
||||
E = config.num_experts
|
||||
topk = config.num_experts_per_tok
|
||||
|
||||
else:
|
||||
# Support for llama4
|
||||
config = config.get_text_config()
|
||||
# Default: Mixtral.
|
||||
E = config.num_local_experts
|
||||
topk = config.num_experts_per_tok
|
||||
|
||||
hidden_size = config.hidden_size
|
||||
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
|
||||
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
||||
use_int8_w8a16 = args.dtype == "int8_w8a16"
|
||||
use_customized_permute = args.use_customized_permute
|
||||
|
||||
if args.batch_size is None:
|
||||
batch_sizes = [
|
||||
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
|
||||
2048, 3072, 4096
|
||||
]
|
||||
else:
|
||||
batch_sizes = [args.batch_size]
|
||||
|
||||
ray.init()
|
||||
num_gpus = int(ray.available_resources()["GPU"])
|
||||
workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
|
||||
|
||||
def _distribute(method: str, inputs: list[Any]) -> list[Any]:
|
||||
outputs = []
|
||||
worker_idx = 0
|
||||
for input_args in inputs:
|
||||
worker = workers[worker_idx]
|
||||
worker_method = getattr(worker, method)
|
||||
output = worker_method.remote(*input_args)
|
||||
outputs.append(output)
|
||||
worker_idx = (worker_idx + 1) % num_gpus
|
||||
return ray.get(outputs)
|
||||
|
||||
outputs = _distribute(
|
||||
"benchmark", [(batch_size, E, hidden_size, topk, dtype, use_fp8_w8a8,
|
||||
use_int8_w8a16, use_customized_permute)
|
||||
for batch_size in batch_sizes])
|
||||
|
||||
for batch_size, (permute, unpermute) in zip(batch_sizes, outputs):
|
||||
print(f"Batch size: {batch_size}")
|
||||
print(f"Permute time: {permute:.2f} us")
|
||||
print(f"Unpermute time: {unpermute:.2f} us")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser()
|
||||
parser.add_argument("--model",
|
||||
type=str,
|
||||
default="mistralai/Mixtral-8x7B-Instruct-v0.1")
|
||||
parser.add_argument("--dtype",
|
||||
type=str,
|
||||
choices=["auto", "fp8_w8a8", "int8_w8a16"],
|
||||
default="auto")
|
||||
parser.add_argument("--use-customized-permute", action="store_true")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--batch-size", type=int, required=False)
|
||||
parser.add_argument("--trust-remote-code", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
133
csrc/moe/moe_permute_unpermute_op.cu
Normal file
133
csrc/moe/moe_permute_unpermute_op.cu
Normal file
@ -0,0 +1,133 @@
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include "permute_unpermute_kernels/moe_permute_unpermute_kernel.h"
|
||||
#include "permute_unpermute_kernels/dispatch.h"
|
||||
#include "core/registration.h"
|
||||
|
||||
void moe_permute(
|
||||
const torch::Tensor& input, // [n_token, hidden]
|
||||
const torch::Tensor& topk_weights, //[n_token, topk]
|
||||
torch::Tensor& topk_ids, // [n_token, topk]
|
||||
const torch::Tensor& token_expert_indicies, // [n_token, topk]
|
||||
const std::optional<torch::Tensor>& expert_map, // [n_expert]
|
||||
int64_t n_expert, int64_t n_local_expert, int64_t topk,
|
||||
const std::optional<int64_t>& align_block_size,
|
||||
torch::Tensor&
|
||||
permuted_input, // [topk * n_token/align_block_size_m, hidden]
|
||||
torch::Tensor& expert_first_token_offset, // [n_local_expert + 1]
|
||||
torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk]
|
||||
torch::Tensor& m_indices) { // [align_expand_m]
|
||||
TORCH_CHECK(topk_weights.scalar_type() == at::ScalarType::Float,
|
||||
"topk_weights must be float32");
|
||||
TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long,
|
||||
"expert_first_token_offset must be int64");
|
||||
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
|
||||
"topk_ids must be int32");
|
||||
TORCH_CHECK(token_expert_indicies.scalar_type() == at::ScalarType::Int,
|
||||
"token_expert_indicies must be int32");
|
||||
TORCH_CHECK(src_row_id2dst_row_id_map.scalar_type() == at::ScalarType::Int,
|
||||
"src_row_id2dst_row_id_map must be int32");
|
||||
TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1,
|
||||
"expert_first_token_offset shape != n_local_expert+1")
|
||||
TORCH_CHECK(
|
||||
src_row_id2dst_row_id_map.sizes() == token_expert_indicies.sizes(),
|
||||
"token_expert_indicies shape must be same as src_row_id2dst_row_id_map");
|
||||
auto n_token = input.sizes()[0];
|
||||
auto n_hidden = input.sizes()[1];
|
||||
auto align_block_size_value =
|
||||
align_block_size.has_value() ? align_block_size.value() : -1;
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
const long sorter_size =
|
||||
CubKeyValueSorter::getWorkspaceSize(n_token * topk, n_expert);
|
||||
auto sort_workspace = torch::empty(
|
||||
{sorter_size},
|
||||
torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
|
||||
auto permuted_experts_id = torch::empty_like(topk_ids);
|
||||
auto dst_row_id2src_row_id_map = torch::empty_like(src_row_id2dst_row_id_map);
|
||||
auto align_expert_first_token_offset =
|
||||
torch::zeros_like(expert_first_token_offset);
|
||||
|
||||
CubKeyValueSorter sorter{};
|
||||
int64_t* valid_num_ptr = nullptr;
|
||||
// pre-process kernel for expert-parallelism:
|
||||
// no local expert id plus "n_expert" offset for priority to local expert
|
||||
// map local expert id [n, .., n+n_local_expert-1] to [0, n_local_expert -1]
|
||||
// For example, 4 expert with ep_size=2. ep_rank=1 owns global expert id
|
||||
// [2,3] with expert_map[-1, -1, 0, 1], preprocess_topk_id process topk_ids
|
||||
// and map global expert id [2, 3] to local_expert id [0, 1] and map global
|
||||
// expert id [0, 1] ( not in ep rank=1) to [4, 5] by plus n_expert. This map
|
||||
// operation is to make local expert high priority in following sort topk_ids
|
||||
// and scan local expert_first_token_offset for each ep rank for next group
|
||||
// gemm.
|
||||
if (expert_map.has_value()) {
|
||||
const int* expert_map_ptr = get_ptr<int>(expert_map.value());
|
||||
valid_num_ptr =
|
||||
get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
|
||||
preprocessTopkIdLauncher(get_ptr<int>(topk_ids), n_token * topk,
|
||||
expert_map_ptr, n_expert, stream);
|
||||
}
|
||||
// expert sort topk expert id and scan expert id get expert_first_token_offset
|
||||
sortAndScanExpert(get_ptr<int>(topk_ids), get_ptr<int>(token_expert_indicies),
|
||||
get_ptr<int>(permuted_experts_id),
|
||||
get_ptr<int>(dst_row_id2src_row_id_map),
|
||||
get_ptr<int64_t>(expert_first_token_offset), n_token,
|
||||
n_expert, n_local_expert, topk, sorter,
|
||||
get_ptr<int>(sort_workspace), stream);
|
||||
|
||||
// dispatch expandInputRowsKernelLauncher
|
||||
MOE_DISPATCH(input.scalar_type(), [&] {
|
||||
expandInputRowsKernelLauncher<scalar_t>(
|
||||
get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input),
|
||||
get_ptr<float>(topk_weights), get_ptr<int>(permuted_experts_id),
|
||||
get_ptr<int>(dst_row_id2src_row_id_map),
|
||||
get_ptr<int>(src_row_id2dst_row_id_map),
|
||||
get_ptr<int64_t>(expert_first_token_offset), n_token, valid_num_ptr,
|
||||
n_hidden, topk, n_local_expert, align_block_size_value, stream);
|
||||
});
|
||||
|
||||
// get m_indices and update expert_first_token_offset with align block
|
||||
getMIndices(get_ptr<int64_t>(expert_first_token_offset),
|
||||
get_ptr<int64_t>(align_expert_first_token_offset),
|
||||
get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
|
||||
stream);
|
||||
if (align_block_size.has_value()) {
|
||||
// update align_expert_first_token_offset
|
||||
expert_first_token_offset.copy_(align_expert_first_token_offset);
|
||||
}
|
||||
}
|
||||
|
||||
void moe_unpermute(
|
||||
const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden]
|
||||
const torch::Tensor& topk_weights, //[n_token, topk]
|
||||
const torch::Tensor& topk_ids, // [n_token, topk]
|
||||
const torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk]
|
||||
const torch::Tensor& expert_first_token_offset, // [n_local_expert+1]
|
||||
int64_t n_expert, int64_t n_local_expert, int64_t topk,
|
||||
torch::Tensor& hidden_states // [n_token, hidden]
|
||||
) {
|
||||
TORCH_CHECK(src_row_id2dst_row_id_map.sizes() == topk_ids.sizes(),
|
||||
"topk_ids shape must be same as src_row_id2dst_row_id_map");
|
||||
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
|
||||
"topk_ids must be int32");
|
||||
TORCH_CHECK(
|
||||
permuted_hidden_states.scalar_type() == hidden_states.scalar_type(),
|
||||
"topk_ids dtype must be same as src_row_id2dst_row_id_map");
|
||||
auto n_token = hidden_states.size(0);
|
||||
auto n_hidden = hidden_states.size(1);
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
const int64_t* valid_ptr =
|
||||
get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
|
||||
MOE_DISPATCH(hidden_states.scalar_type(), [&] {
|
||||
finalizeMoeRoutingKernelLauncher<scalar_t, scalar_t>(
|
||||
get_ptr<scalar_t>(permuted_hidden_states),
|
||||
get_ptr<scalar_t>(hidden_states), get_ptr<float>(topk_weights),
|
||||
get_ptr<int>(src_row_id2dst_row_id_map), get_ptr<int>(topk_ids),
|
||||
n_token, n_hidden, topk, valid_ptr, stream);
|
||||
});
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("moe_permute", &moe_permute);
|
||||
m.impl("moe_unpermute", &moe_unpermute);
|
||||
}
|
||||
53
csrc/moe/permute_unpermute_kernels/dispatch.h
Normal file
53
csrc/moe/permute_unpermute_kernels/dispatch.h
Normal file
@ -0,0 +1,53 @@
|
||||
#pragma once
|
||||
#include <cuda_fp8.h>
|
||||
#define MOE_SWITCH(TYPE, ...) \
|
||||
at::ScalarType _st = ::detail::scalar_type(TYPE); \
|
||||
switch (_st) { \
|
||||
__VA_ARGS__ \
|
||||
default: \
|
||||
TORCH_CHECK(false, "[moe permute]data type dispatch fail!") \
|
||||
}
|
||||
|
||||
#define MOE_DISPATCH_CASE(enum_type, ...) \
|
||||
case enum_type: { \
|
||||
using scalar_t = ScalarType2CudaType<enum_type>::type; \
|
||||
__VA_ARGS__(); \
|
||||
break; \
|
||||
}
|
||||
#define MOE_DISPATCH_FLOAT_CASE(...) \
|
||||
MOE_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
MOE_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||
MOE_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||
MOE_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) \
|
||||
MOE_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
|
||||
|
||||
#define MOE_DISPATCH(TYPE, ...) \
|
||||
MOE_SWITCH(TYPE, MOE_DISPATCH_FLOAT_CASE(__VA_ARGS__))
|
||||
|
||||
template <at::ScalarType type>
|
||||
struct ScalarType2CudaType;
|
||||
|
||||
template <>
|
||||
struct ScalarType2CudaType<at::ScalarType::Float> {
|
||||
using type = float;
|
||||
};
|
||||
template <>
|
||||
struct ScalarType2CudaType<at::ScalarType::Half> {
|
||||
using type = half;
|
||||
};
|
||||
template <>
|
||||
struct ScalarType2CudaType<at::ScalarType::BFloat16> {
|
||||
using type = __nv_bfloat16;
|
||||
};
|
||||
|
||||
// #if __CUDA_ARCH__ >= 890
|
||||
// fp8
|
||||
template <>
|
||||
struct ScalarType2CudaType<at::ScalarType::Float8_e5m2> {
|
||||
using type = __nv_fp8_e5m2;
|
||||
};
|
||||
template <>
|
||||
struct ScalarType2CudaType<at::ScalarType::Float8_e4m3fn> {
|
||||
using type = __nv_fp8_e4m3;
|
||||
};
|
||||
// #endif
|
||||
@ -0,0 +1,229 @@
|
||||
|
||||
#include "moe_permute_unpermute_kernel.h"
|
||||
|
||||
// CubKeyValueSorter definition begin
|
||||
CubKeyValueSorter::CubKeyValueSorter()
|
||||
: num_experts_(0), num_bits_(sizeof(int) * 8) {}
|
||||
|
||||
int CubKeyValueSorter::expertsToBits(int num_experts) {
|
||||
// Max value we represent is V = num_experts + (num_experts - 1) = 2 *
|
||||
// num_experts - 1 The maximum number of bits is therefore floor(log2(V)) + 1
|
||||
return static_cast<int>(log2(2 * num_experts - 1)) + 1;
|
||||
}
|
||||
|
||||
CubKeyValueSorter::CubKeyValueSorter(int const num_experts)
|
||||
: num_experts_(num_experts), num_bits_(expertsToBits(num_experts)) {}
|
||||
|
||||
void CubKeyValueSorter::updateNumExperts(int const num_experts) {
|
||||
num_experts_ = num_experts;
|
||||
num_bits_ = expertsToBits(num_experts);
|
||||
}
|
||||
|
||||
size_t CubKeyValueSorter::getWorkspaceSize(size_t const num_key_value_pairs,
|
||||
int const num_experts) {
|
||||
int num_bits = expertsToBits(num_experts);
|
||||
size_t required_storage = 0;
|
||||
int* null_int = nullptr;
|
||||
cub::DeviceRadixSort::SortPairs(nullptr, required_storage, null_int, null_int,
|
||||
null_int, null_int, num_key_value_pairs, 0,
|
||||
num_bits);
|
||||
|
||||
// when num_key_value_pairs, num_experts, num_bits, required_storage = 64,
|
||||
// 4, 3, 0 The required_storage seems to vary between 0 and 1 for the same
|
||||
// inputs
|
||||
if (required_storage == 0) {
|
||||
required_storage = 1;
|
||||
}
|
||||
return required_storage;
|
||||
}
|
||||
|
||||
void CubKeyValueSorter::run(void* workspace, size_t const workspace_size,
|
||||
int const* keys_in, int* keys_out,
|
||||
int const* values_in, int* values_out,
|
||||
size_t const num_key_value_pairs,
|
||||
cudaStream_t stream) {
|
||||
size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs, num_experts_);
|
||||
size_t actual_ws_size = workspace_size;
|
||||
|
||||
TORCH_CHECK(expected_ws_size <= workspace_size,
|
||||
"[CubKeyValueSorter::run] The allocated workspace is too small "
|
||||
"to run this problem.");
|
||||
cub::DeviceRadixSort::SortPairs(workspace, actual_ws_size, keys_in, keys_out,
|
||||
values_in, values_out, num_key_value_pairs, 0,
|
||||
num_bits_, stream);
|
||||
}
|
||||
// CubKeyValueSorter definition end
|
||||
|
||||
static inline size_t pad_to_multiple_of_16(size_t const& input) {
|
||||
static constexpr int ALIGNMENT = 16;
|
||||
return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT);
|
||||
}
|
||||
template <class T>
|
||||
__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices,
|
||||
int64_t const arr_length,
|
||||
T const target) {
|
||||
int64_t low = 0, high = arr_length - 1, target_location = -1;
|
||||
while (low <= high) {
|
||||
int64_t mid = (low + high) / 2;
|
||||
|
||||
if (sorted_indices[mid] >= target) {
|
||||
high = mid - 1;
|
||||
} else {
|
||||
low = mid + 1;
|
||||
target_location = mid;
|
||||
}
|
||||
}
|
||||
return target_location + 1;
|
||||
}
|
||||
|
||||
// Calculates the start offset of the tokens for a given expert. The last
|
||||
// element is the total number of valid tokens
|
||||
__global__ void computeExpertFirstTokenOffsetKernel(
|
||||
int const* sorted_experts, int64_t const sorted_experts_len,
|
||||
int const num_experts, int64_t* expert_first_token_offset) {
|
||||
// First, compute the global tid. We only need 1 thread per expert.
|
||||
int const expert = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
// Note that expert goes [0, num_experts] (inclusive) because we want a count
|
||||
// for the total number of active tokens at the end of the scan.
|
||||
if (expert >= num_experts + 1) {
|
||||
return;
|
||||
}
|
||||
expert_first_token_offset[expert] =
|
||||
findTotalEltsLessThanTarget(sorted_experts, sorted_experts_len, expert);
|
||||
}
|
||||
|
||||
void computeExpertFirstTokenOffset(int const* sorted_indices,
|
||||
int const total_indices,
|
||||
int const num_experts,
|
||||
int64_t* expert_first_token_offset,
|
||||
cudaStream_t stream) {
|
||||
int const num_entries = num_experts + 1;
|
||||
int const threads = std::min(1024, num_entries);
|
||||
int const blocks = (num_entries + threads - 1) / threads;
|
||||
|
||||
computeExpertFirstTokenOffsetKernel<<<blocks, threads, 0, stream>>>(
|
||||
sorted_indices, total_indices, num_experts, expert_first_token_offset);
|
||||
}
|
||||
|
||||
void sortAndScanExpert(int* expert_for_source_row, const int* source_rows,
|
||||
int* permuted_experts, int* permuted_rows,
|
||||
int64_t* expert_first_token_offset, int num_rows,
|
||||
int num_experts, int num_experts_per_node, int k,
|
||||
CubKeyValueSorter& sorter, void* sorter_ws,
|
||||
cudaStream_t stream) {
|
||||
int64_t const expanded_num_rows = static_cast<int64_t>(k) * num_rows;
|
||||
// We need to use the full num_experts because that is the sentinel value used
|
||||
// by topk for disabled experts
|
||||
sorter.updateNumExperts(num_experts);
|
||||
size_t const sorter_ws_size_bytes = pad_to_multiple_of_16(
|
||||
sorter.getWorkspaceSize(expanded_num_rows, num_experts));
|
||||
sorter.run((void*)sorter_ws, sorter_ws_size_bytes, expert_for_source_row,
|
||||
permuted_experts, source_rows, permuted_rows, expanded_num_rows,
|
||||
stream);
|
||||
computeExpertFirstTokenOffset(permuted_experts, expanded_num_rows,
|
||||
num_experts_per_node, expert_first_token_offset,
|
||||
stream);
|
||||
}
|
||||
|
||||
__global__ void preprocessTopkIdKernel(int* topk_id_ptr, int size,
|
||||
const int* expert_map_ptr,
|
||||
int num_experts) {
|
||||
auto tidx = threadIdx.x;
|
||||
auto bidx = blockIdx.x;
|
||||
auto lidx = tidx & 31;
|
||||
auto widx = tidx >> 5;
|
||||
auto warp_count = (blockDim.x + 31) >> 5;
|
||||
auto offset = bidx * blockDim.x;
|
||||
auto bound = min(offset + blockDim.x, size);
|
||||
extern __shared__ int smem_expert_map[];
|
||||
// store expert_map in smem
|
||||
for (int i = tidx; i < num_experts; i += blockDim.x) {
|
||||
smem_expert_map[i] = expert_map_ptr[i];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// query global expert id in expert map.
|
||||
// if global expert id = -1 in exert map, plus n_expert
|
||||
// else set global expert id = exert map[global expert id]
|
||||
if (offset + tidx < bound) {
|
||||
auto topk_id = topk_id_ptr[offset + tidx];
|
||||
auto local_expert_idx = smem_expert_map[topk_id];
|
||||
if (local_expert_idx == -1) {
|
||||
topk_id += num_experts;
|
||||
} else {
|
||||
topk_id = local_expert_idx;
|
||||
}
|
||||
__syncwarp();
|
||||
topk_id_ptr[offset + tidx] = topk_id;
|
||||
}
|
||||
}
|
||||
void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
|
||||
const int* expert_map_ptr, int num_experts,
|
||||
cudaStream_t stream) {
|
||||
int block = std::min(size, 1024);
|
||||
int grid = (size + block - 1) / block;
|
||||
int smem_size = (num_experts) * sizeof(int);
|
||||
preprocessTopkIdKernel<<<grid, block, smem_size, stream>>>(
|
||||
topk_id_ptr, size, expert_map_ptr, num_experts);
|
||||
}
|
||||
|
||||
template <bool ALIGN_BLOCK_SIZE>
|
||||
__global__ void getMIndicesKernel(int64_t* expert_first_token_offset,
|
||||
int64_t* align_expert_first_token_offset,
|
||||
int* m_indices, const int num_local_expert,
|
||||
const int align_block_size) {
|
||||
int eidx = blockIdx.x;
|
||||
int tidx = threadIdx.x;
|
||||
extern __shared__ int64_t smem_expert_first_token_offset[];
|
||||
for (int i = tidx; i <= num_local_expert; i += blockDim.x) {
|
||||
smem_expert_first_token_offset[tidx] = __ldg(expert_first_token_offset + i);
|
||||
}
|
||||
__syncthreads();
|
||||
auto last_token_offset = smem_expert_first_token_offset[eidx + 1];
|
||||
auto first_token_offset = smem_expert_first_token_offset[eidx];
|
||||
int n_token_in_expert = last_token_offset - first_token_offset;
|
||||
|
||||
if constexpr (ALIGN_BLOCK_SIZE) {
|
||||
n_token_in_expert = (n_token_in_expert + align_block_size - 1) /
|
||||
align_block_size * align_block_size;
|
||||
// round up to ALIGN_BLOCK_SIZE
|
||||
int64_t accumulate_align_offset = 0;
|
||||
for (int i = 1; i <= eidx + 1; i++) {
|
||||
int n_token = smem_expert_first_token_offset[i] -
|
||||
smem_expert_first_token_offset[i - 1];
|
||||
accumulate_align_offset =
|
||||
accumulate_align_offset + (n_token + align_block_size - 1) /
|
||||
align_block_size * align_block_size;
|
||||
if (i == eidx) {
|
||||
first_token_offset = accumulate_align_offset;
|
||||
}
|
||||
// last block store align_expert_first_token_offset
|
||||
if (eidx == num_local_expert - 1 && threadIdx.x == 0) {
|
||||
align_expert_first_token_offset[i] = accumulate_align_offset;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int idx = tidx; idx < n_token_in_expert; idx += blockDim.x) {
|
||||
// update m_indice with expert id
|
||||
m_indices[first_token_offset + idx] = eidx;
|
||||
}
|
||||
}
|
||||
|
||||
void getMIndices(int64_t* expert_first_token_offset,
|
||||
int64_t* align_expert_first_token_offset, int* m_indices,
|
||||
int num_local_expert, const int align_block_size,
|
||||
cudaStream_t stream) {
|
||||
int block = 256;
|
||||
int grid = num_local_expert;
|
||||
int smem_size = sizeof(int64_t) * (num_local_expert + 1);
|
||||
if (align_block_size == -1) {
|
||||
getMIndicesKernel<false><<<grid, block, smem_size, stream>>>(
|
||||
expert_first_token_offset, align_expert_first_token_offset, m_indices,
|
||||
num_local_expert, align_block_size);
|
||||
} else {
|
||||
getMIndicesKernel<true><<<grid, block, smem_size, stream>>>(
|
||||
expert_first_token_offset, align_expert_first_token_offset, m_indices,
|
||||
num_local_expert, align_block_size);
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,95 @@
|
||||
#pragma once
|
||||
// reference from tensorrt_llm moe kernel implementation archive in
|
||||
// https://github.com/BBuf/tensorrt-llm-moe/tree/master
|
||||
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <torch/all.h>
|
||||
#include "dispatch.h"
|
||||
#include <cub/cub.cuh>
|
||||
#include <cub/device/device_radix_sort.cuh>
|
||||
#include <cub/util_type.cuh>
|
||||
#include "cutlass/numeric_size.h"
|
||||
#include "cutlass/array.h"
|
||||
|
||||
template <typename T>
|
||||
inline T* get_ptr(torch::Tensor& t) {
|
||||
return reinterpret_cast<T*>(t.data_ptr());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline const T* get_ptr(const torch::Tensor& t) {
|
||||
return reinterpret_cast<const T*>(t.data_ptr());
|
||||
}
|
||||
|
||||
class CubKeyValueSorter {
|
||||
public:
|
||||
CubKeyValueSorter();
|
||||
|
||||
CubKeyValueSorter(int const num_experts);
|
||||
|
||||
void updateNumExperts(int const num_experts);
|
||||
|
||||
static size_t getWorkspaceSize(size_t const num_key_value_pairs,
|
||||
int const num_experts);
|
||||
|
||||
void run(void* workspace, size_t const workspace_size, int const* keys_in,
|
||||
int* keys_out, int const* values_in, int* values_out,
|
||||
size_t const num_key_value_pairs, cudaStream_t stream);
|
||||
|
||||
private:
|
||||
static int expertsToBits(int experts);
|
||||
int num_experts_;
|
||||
int num_bits_;
|
||||
};
|
||||
|
||||
void computeExpertFirstTokenOffset(int const* sorted_indices,
|
||||
int const total_indices,
|
||||
int const num_experts,
|
||||
int64_t* expert_first_token_offset,
|
||||
cudaStream_t stream);
|
||||
|
||||
void sortAndScanExpert(int* expert_for_source_row, const int* source_rows,
|
||||
int* permuted_experts, int* permuted_rows,
|
||||
int64_t* expert_first_token_offset, int num_rows,
|
||||
int num_experts, int num_experts_per_node, int k,
|
||||
CubKeyValueSorter& sorter, void* sorter_ws,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void expandInputRowsKernelLauncher(
|
||||
T const* unpermuted_input, T* permuted_output,
|
||||
const float* unpermuted_scales, int* sorted_experts,
|
||||
int const* expanded_dest_row_to_expanded_source_row,
|
||||
int* expanded_source_row_to_expanded_dest_row,
|
||||
int64_t* expert_first_token_offset, int64_t const num_rows,
|
||||
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
|
||||
int num_local_experts, const int& align_block_size, cudaStream_t stream);
|
||||
|
||||
// Final kernel to unpermute and scale
|
||||
// This kernel unpermutes the original data, does the k-way reduction and
|
||||
// performs the final skip connection.
|
||||
template <typename T, typename OutputType, bool CHECK_SKIPPED>
|
||||
__global__ void finalizeMoeRoutingKernel(
|
||||
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
|
||||
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
|
||||
int const* expert_for_source_row, int64_t const orig_cols, int64_t const k,
|
||||
int64_t const* num_valid_ptr);
|
||||
|
||||
template <class T, class OutputType>
|
||||
void finalizeMoeRoutingKernelLauncher(
|
||||
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
|
||||
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
|
||||
int const* expert_for_source_row, int64_t const num_rows,
|
||||
int64_t const cols, int64_t const k, int64_t const* num_valid_ptr,
|
||||
cudaStream_t stream);
|
||||
|
||||
void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
|
||||
const int* expert_map_ptr, int num_experts,
|
||||
cudaStream_t stream);
|
||||
|
||||
void getMIndices(int64_t* expert_first_token_offset,
|
||||
int64_t* align_expert_first_token_offset, int* m_indices,
|
||||
int num_local_expert, const int align_block_size,
|
||||
cudaStream_t stream);
|
||||
|
||||
#include "moe_permute_unpermute_kernel.inl"
|
||||
@ -0,0 +1,211 @@
|
||||
#pragma once
|
||||
|
||||
template <typename T, bool CHECK_SKIPPED, bool ALIGN_BLOCK_SIZE>
|
||||
__global__ void expandInputRowsKernel(
|
||||
T const* unpermuted_input, T* permuted_output,
|
||||
const float* unpermuted_scales, int* sorted_experts,
|
||||
int const* expanded_dest_row_to_expanded_source_row,
|
||||
int* expanded_source_row_to_expanded_dest_row,
|
||||
int64_t* expert_first_token_offset, int64_t const num_rows,
|
||||
int64_t const* num_dest_rows, int64_t const cols, int64_t k,
|
||||
int num_local_experts, int align_block_size) {
|
||||
// Reverse permutation map.
|
||||
// I do this so that later, we can use the source -> dest map to do the k-way
|
||||
// reduction and unpermuting. I need the reverse map for that reduction to
|
||||
// allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1
|
||||
// thread block will be responsible for all k summations.
|
||||
int64_t expanded_dest_row = blockIdx.x;
|
||||
int64_t const expanded_source_row =
|
||||
expanded_dest_row_to_expanded_source_row[expanded_dest_row];
|
||||
int expert_id = sorted_experts[expanded_dest_row];
|
||||
|
||||
extern __shared__ int64_t smem_expert_first_token_offset[];
|
||||
int64_t align_expanded_row_accumulate = 0;
|
||||
if constexpr (ALIGN_BLOCK_SIZE) {
|
||||
// load g2s
|
||||
for (int idx = threadIdx.x; idx < num_local_experts + 1;
|
||||
idx += blockDim.x) {
|
||||
smem_expert_first_token_offset[idx] =
|
||||
__ldg(expert_first_token_offset + idx);
|
||||
}
|
||||
__syncthreads();
|
||||
int lane_idx = threadIdx.x & 31;
|
||||
|
||||
if (lane_idx == 0) {
|
||||
// set token_offset_in_expert = 0 if this expert is not local expert
|
||||
int token_offset_in_expert =
|
||||
expert_id >= num_local_experts
|
||||
? 0
|
||||
: expanded_dest_row - smem_expert_first_token_offset[expert_id];
|
||||
int64_t accumulate_align_offset = 0;
|
||||
#pragma unroll 1
|
||||
for (int eidx = 1; eidx <= min(expert_id, num_local_experts); eidx++) {
|
||||
auto n_token_in_expert = smem_expert_first_token_offset[eidx] -
|
||||
smem_expert_first_token_offset[eidx - 1];
|
||||
accumulate_align_offset += (n_token_in_expert + align_block_size - 1) /
|
||||
align_block_size * align_block_size;
|
||||
}
|
||||
expanded_dest_row = accumulate_align_offset + token_offset_in_expert;
|
||||
}
|
||||
// lane0 shuffle broadcast align_expanded_dest_row
|
||||
expanded_dest_row = __shfl_sync(0xffffffff, expanded_dest_row, 0);
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
assert(expanded_dest_row <= INT32_MAX);
|
||||
expanded_source_row_to_expanded_dest_row[expanded_source_row] =
|
||||
static_cast<int>(expanded_dest_row);
|
||||
}
|
||||
|
||||
if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) {
|
||||
// Load 128-bits per thread
|
||||
constexpr int64_t ELEM_PER_THREAD = 128 / cutlass::sizeof_bits<T>::value;
|
||||
using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;
|
||||
|
||||
// Duplicate and permute rows
|
||||
int64_t const source_k_rank = expanded_source_row / num_rows;
|
||||
int64_t const source_row = expanded_source_row % num_rows;
|
||||
|
||||
auto const* source_row_ptr =
|
||||
reinterpret_cast<DataElem const*>(unpermuted_input + source_row * cols);
|
||||
auto* dest_row_ptr =
|
||||
reinterpret_cast<DataElem*>(permuted_output + expanded_dest_row * cols);
|
||||
|
||||
int64_t const start_offset = threadIdx.x;
|
||||
int64_t const stride = blockDim.x;
|
||||
int64_t const num_elems_in_col = cols / ELEM_PER_THREAD;
|
||||
|
||||
for (int elem_index = start_offset; elem_index < num_elems_in_col;
|
||||
elem_index += stride) {
|
||||
dest_row_ptr[elem_index] = source_row_ptr[elem_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void expandInputRowsKernelLauncher(
|
||||
T const* unpermuted_input, T* permuted_output,
|
||||
const float* unpermuted_scales, int* sorted_experts,
|
||||
int const* expanded_dest_row_to_expanded_source_row,
|
||||
int* expanded_source_row_to_expanded_dest_row,
|
||||
int64_t* expert_first_token_offset, int64_t const num_rows,
|
||||
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
|
||||
int num_local_experts, const int& align_block_size, cudaStream_t stream) {
|
||||
int64_t const blocks = num_rows * k;
|
||||
int64_t const threads = 256;
|
||||
using FuncPtr = decltype(&expandInputRowsKernel<T, true, true>);
|
||||
FuncPtr func_map[2][2] = {
|
||||
{&expandInputRowsKernel<T, false, false>,
|
||||
&expandInputRowsKernel<T, false, true>},
|
||||
{&expandInputRowsKernel<T, true, false>,
|
||||
&expandInputRowsKernel<T, true, true>},
|
||||
};
|
||||
bool is_check_skip = num_valid_tokens_ptr != nullptr;
|
||||
bool is_align_block_size = align_block_size != -1;
|
||||
auto func = func_map[is_check_skip][is_align_block_size];
|
||||
|
||||
int64_t smem_size = sizeof(int64_t) * (num_local_experts + 1);
|
||||
|
||||
func<<<blocks, threads, smem_size, stream>>>(
|
||||
unpermuted_input, permuted_output, unpermuted_scales, sorted_experts,
|
||||
expanded_dest_row_to_expanded_source_row,
|
||||
expanded_source_row_to_expanded_dest_row, expert_first_token_offset,
|
||||
num_rows, num_valid_tokens_ptr, cols, k, num_local_experts,
|
||||
align_block_size);
|
||||
}
|
||||
|
||||
template <class T, class U>
|
||||
__host__ __device__ constexpr static U arrayConvert(T const& input) {
|
||||
using Type = typename U::Element;
|
||||
static_assert(T::kElements == U::kElements);
|
||||
U u;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < U::kElements; i++) {
|
||||
u[i] = static_cast<Type>(input[i]);
|
||||
}
|
||||
return u;
|
||||
}
|
||||
|
||||
template <typename T, typename OutputType, bool CHECK_SKIPPED>
|
||||
__global__ void finalizeMoeRoutingKernel(
|
||||
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
|
||||
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
|
||||
int const* expert_for_source_row, int64_t const orig_cols, int64_t const k,
|
||||
int64_t const* num_valid_ptr) {
|
||||
assert(orig_cols % 4 == 0);
|
||||
int64_t const original_row = blockIdx.x;
|
||||
int64_t const num_rows = gridDim.x;
|
||||
auto const offset = original_row * orig_cols;
|
||||
OutputType* reduced_row_ptr = reduced_unpermuted_output + offset;
|
||||
int64_t const num_valid = *num_valid_ptr;
|
||||
|
||||
// Load 128-bits per thread, according to the smallest data type we read/write
|
||||
constexpr int64_t FINALIZE_ELEM_PER_THREAD =
|
||||
128 / std::min(cutlass::sizeof_bits<OutputType>::value,
|
||||
cutlass::sizeof_bits<T>::value);
|
||||
|
||||
int64_t const start_offset = threadIdx.x;
|
||||
int64_t const stride = blockDim.x;
|
||||
int64_t const num_elems_in_col = orig_cols / FINALIZE_ELEM_PER_THREAD;
|
||||
|
||||
using InputElem = cutlass::Array<T, FINALIZE_ELEM_PER_THREAD>;
|
||||
using OutputElem = cutlass::Array<OutputType, FINALIZE_ELEM_PER_THREAD>;
|
||||
using ComputeElem = cutlass::Array<float, FINALIZE_ELEM_PER_THREAD>;
|
||||
auto const* expanded_permuted_rows_v =
|
||||
reinterpret_cast<InputElem const*>(expanded_permuted_rows);
|
||||
auto* reduced_row_ptr_v = reinterpret_cast<OutputElem*>(reduced_row_ptr);
|
||||
|
||||
#pragma unroll
|
||||
for (int elem_index = start_offset; elem_index < num_elems_in_col;
|
||||
elem_index += stride) {
|
||||
ComputeElem thread_output;
|
||||
thread_output.fill(0);
|
||||
float row_rescale{0.f};
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||
int64_t const expanded_original_row = original_row + k_idx * num_rows;
|
||||
int64_t const expanded_permuted_row =
|
||||
expanded_source_row_to_expanded_dest_row[expanded_original_row];
|
||||
|
||||
int64_t const k_offset = original_row * k + k_idx;
|
||||
float const row_scale = scales[k_offset];
|
||||
|
||||
// Check after row_rescale has accumulated
|
||||
if (CHECK_SKIPPED && expanded_permuted_row >= num_valid) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto const* expanded_permuted_rows_row_ptr =
|
||||
expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col;
|
||||
|
||||
int64_t const expert_idx = expert_for_source_row[k_offset];
|
||||
|
||||
ComputeElem expert_result = arrayConvert<InputElem, ComputeElem>(
|
||||
expanded_permuted_rows_row_ptr[elem_index]);
|
||||
thread_output = thread_output + row_scale * (expert_result);
|
||||
}
|
||||
|
||||
OutputElem output_elem =
|
||||
arrayConvert<ComputeElem, OutputElem>(thread_output);
|
||||
reduced_row_ptr_v[elem_index] = output_elem;
|
||||
}
|
||||
}
|
||||
|
||||
template <class T, class OutputType>
|
||||
void finalizeMoeRoutingKernelLauncher(
|
||||
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
|
||||
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
|
||||
int const* expert_for_source_row, int64_t const num_rows,
|
||||
int64_t const cols, int64_t const k, int64_t const* num_valid_ptr,
|
||||
cudaStream_t stream) {
|
||||
int64_t const blocks = num_rows;
|
||||
int64_t const threads = 256;
|
||||
bool const check_finished = num_valid_ptr != nullptr;
|
||||
using FuncPtr = decltype(&finalizeMoeRoutingKernel<T, OutputType, false>);
|
||||
FuncPtr func_map[2] = {&finalizeMoeRoutingKernel<T, OutputType, false>,
|
||||
&finalizeMoeRoutingKernel<T, OutputType, true>};
|
||||
auto* const kernel = func_map[check_finished];
|
||||
kernel<<<blocks, threads, 0, stream>>>(
|
||||
expanded_permuted_rows, reduced_unpermuted_output, scales,
|
||||
expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k,
|
||||
num_valid_ptr);
|
||||
}
|
||||
@ -53,7 +53,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
"int size_m, int size_n, int size_k,"
|
||||
"bool is_full_k, bool use_atomic_add,"
|
||||
"bool use_fp32_reduce, bool is_zp_float) -> Tensor");
|
||||
m.def(
|
||||
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
|
||||
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
|
||||
"b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
|
||||
"int b_q_type, SymInt size_m, "
|
||||
"SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int "
|
||||
"topk, "
|
||||
"int moe_block_size, bool replicate_input, bool apply_weights)"
|
||||
" -> Tensor");
|
||||
|
||||
m.def(
|
||||
"moe_permute(Tensor input, Tensor topk_weight, Tensor! topk_ids,"
|
||||
"Tensor token_expert_indicies, Tensor? expert_map, int n_expert,"
|
||||
"int n_local_expert,"
|
||||
"int topk, int? align_block_size,Tensor! permuted_input, Tensor! "
|
||||
"expert_first_token_offset, Tensor! src_row_id2dst_row_id_map, Tensor! "
|
||||
"m_indices)->()");
|
||||
|
||||
m.def(
|
||||
"moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights,"
|
||||
"Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor "
|
||||
"expert_first_token_offset, int n_expert, int n_local_expert,int "
|
||||
"topk, Tensor! hidden_states)->()");
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
#endif
|
||||
|
||||
@ -420,7 +420,8 @@ def test_fused_marlin_moe(
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
|
||||
topk_weights, topk_ids = fused_topk(a, score, topk, False)
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
a, score, topk, False)
|
||||
|
||||
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)
|
||||
|
||||
|
||||
223
tests/kernels/moe/test_moe_permute_unpermute.py
Normal file
223
tests/kernels/moe/test_moe_permute_unpermute.py
Normal file
@ -0,0 +1,223 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Tests for the MOE permute/unpermute kernel
|
||||
|
||||
Run `pytest tests/kernels/test_moe_permute_unpermute.py`.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.layer import determine_expert_map
|
||||
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||
moe_permute, moe_unpermute)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
NUM_EXPERTS = [16, 64]
|
||||
TOP_KS = [2, 4, 6, 8]
|
||||
EP_SIZE = [1, 4, 16]
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
|
||||
def torch_permute(hidden_states: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
token_expert_indices: torch.Tensor,
|
||||
topk: int,
|
||||
n_expert: int,
|
||||
n_local_expert: int,
|
||||
start_expert: int,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
align_block_size: Optional[int] = None,
|
||||
fill_invalid_expert: int = -1) -> list[torch.Tensor]:
|
||||
n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1]
|
||||
if expert_map is not None:
|
||||
is_local_expert = (expert_map[topk_ids] != -1)
|
||||
not_local_expert = (expert_map[topk_ids] == -1)
|
||||
topk_ids = is_local_expert * (
|
||||
topk_ids - start_expert) + not_local_expert * (topk_ids + n_expert)
|
||||
|
||||
sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(),
|
||||
stable=True)
|
||||
dst_row_id2src_row_id_map = token_expert_indices.flatten()[sorted_indices]
|
||||
|
||||
expert_first_token_offset = torch.zeros(n_local_expert + 1,
|
||||
dtype=torch.int64,
|
||||
device="cuda")
|
||||
idx = 0
|
||||
for i in range(0, n_local_expert):
|
||||
cnt = 0
|
||||
while idx < sorted_topk_ids.numel() and sorted_topk_ids[idx] == i:
|
||||
cnt += 1
|
||||
idx += 1
|
||||
expert_first_token_offset[i + 1] = expert_first_token_offset[i] + cnt
|
||||
|
||||
_, src2dst_idx = torch.sort(dst_row_id2src_row_id_map)
|
||||
valid_row_idx = []
|
||||
if align_block_size is None:
|
||||
|
||||
permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map %
|
||||
n_token, ...]
|
||||
permuted_row_size = permuted_hidden_states.shape[0]
|
||||
m_indices = torch.empty(permuted_row_size,
|
||||
device="cuda",
|
||||
dtype=torch.int32).fill_(fill_invalid_expert)
|
||||
for i in range(1, n_local_expert + 1):
|
||||
first_token_offset = expert_first_token_offset[i - 1]
|
||||
last_token_offset = expert_first_token_offset[i]
|
||||
m_indices[first_token_offset:last_token_offset] = i - 1
|
||||
src_row_id2dst_row_id_map = torch.arange(
|
||||
0, n_token * topk, device="cuda",
|
||||
dtype=torch.int32)[src2dst_idx].reshape((n_token, topk))
|
||||
valid_row_idx += [i for i in range(expert_first_token_offset[-1])]
|
||||
return [
|
||||
permuted_hidden_states, expert_first_token_offset,
|
||||
src_row_id2dst_row_id_map, m_indices, valid_row_idx
|
||||
]
|
||||
else:
|
||||
permuted_row_size = (topk * n_token + n_expert *
|
||||
(align_block_size - 1) + align_block_size -
|
||||
1) // align_block_size * align_block_size
|
||||
permuted_hidden_states = torch.empty((permuted_row_size, n_hidden),
|
||||
device="cuda",
|
||||
dtype=hidden_states.dtype)
|
||||
align_src_row_id2dst_row_id = torch.empty(n_token * topk,
|
||||
device="cuda",
|
||||
dtype=torch.int32)
|
||||
align_expert_first_token_offset = torch.zeros_like(
|
||||
expert_first_token_offset)
|
||||
m_indices = torch.empty(permuted_row_size,
|
||||
device="cuda",
|
||||
dtype=torch.int32).fill_(fill_invalid_expert)
|
||||
# get align_permuted_hidden_states,
|
||||
# valid row_idx and align_expert_first_token_offset
|
||||
for i in range(1, n_local_expert + 1):
|
||||
first_token_offset = expert_first_token_offset[i - 1]
|
||||
last_token_offset = expert_first_token_offset[i]
|
||||
n_token_in_expert = last_token_offset - first_token_offset
|
||||
align_expert_first_token_offset[
|
||||
i] = align_expert_first_token_offset[
|
||||
i - 1] + (n_token_in_expert + align_block_size -
|
||||
1) // align_block_size * align_block_size
|
||||
align_first_token_offset = align_expert_first_token_offset[i - 1]
|
||||
align_last_token_offset = align_expert_first_token_offset[i]
|
||||
dst_row_id2src_row_id_in_expert = dst_row_id2src_row_id_map[
|
||||
first_token_offset:first_token_offset +
|
||||
n_token_in_expert] % n_token
|
||||
# store token in current expert with align_first_token_offset
|
||||
permuted_hidden_states[align_first_token_offset:\
|
||||
align_first_token_offset+n_token_in_expert,\
|
||||
...] = hidden_states[\
|
||||
dst_row_id2src_row_id_in_expert, ...]
|
||||
# set current expert m_indices
|
||||
m_indices[align_first_token_offset:align_last_token_offset] = i - 1
|
||||
valid_row_idx += [
|
||||
i for i in range(align_first_token_offset,
|
||||
align_first_token_offset + n_token_in_expert)
|
||||
]
|
||||
# get align_src_row_id2dst_row_id
|
||||
for i in range(n_token * topk):
|
||||
eid = sorted_topk_ids[i]
|
||||
if (eid >= n_local_expert):
|
||||
# check token not in local expert
|
||||
align_src_row_id2dst_row_id[
|
||||
i] = align_expert_first_token_offset[-1]
|
||||
continue
|
||||
first_token_offset = expert_first_token_offset[eid]
|
||||
align_first_token_offset = align_expert_first_token_offset[eid]
|
||||
token_offset = i - first_token_offset
|
||||
align_src_row_id2dst_row_id[
|
||||
i] = align_first_token_offset + token_offset
|
||||
align_src_row_id2dst_row_id = align_src_row_id2dst_row_id[\
|
||||
src2dst_idx].reshape((n_token, topk))
|
||||
return [
|
||||
permuted_hidden_states, align_expert_first_token_offset,
|
||||
align_src_row_id2dst_row_id, m_indices, valid_row_idx
|
||||
]
|
||||
|
||||
|
||||
def torch_unpermute(permuted_hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||
token_expert_indices: torch.Tensor,
|
||||
src_row_id2dst_row_id_map: torch.Tensor,
|
||||
valid_row_idx: torch.Tensor, topk: int,
|
||||
n_expert: int) -> torch.Tensor:
|
||||
# ignore invalid row
|
||||
mask = torch.zeros(permuted_hidden_states.shape[0],
|
||||
dtype=bool,
|
||||
device="cuda")
|
||||
mask[valid_row_idx] = True
|
||||
permuted_hidden_states[~mask] = 0
|
||||
idx = src_row_id2dst_row_id_map.flatten()[
|
||||
token_expert_indices.flatten()].reshape(token_expert_indices.shape)
|
||||
output = permuted_hidden_states[idx, ...] * topk_weights[..., None]
|
||||
output = output.sum(dim=1).to(permuted_hidden_states.dtype)
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_token", [1, 33, 64, 222, 1024, 2048, 3000, 5000])
|
||||
@pytest.mark.parametrize("n_hidden", [2048, 4096, 7168])
|
||||
@pytest.mark.parametrize("n_expert", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("ep_size", EP_SIZE)
|
||||
@pytest.mark.parametrize("align_block_size", [None, 128])
|
||||
def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
|
||||
n_expert: int, ep_size: int, dtype: torch.dtype,
|
||||
align_block_size: Optional[int]):
|
||||
fill_invalid_expert = 0
|
||||
ep_rank = np.random.randint(0, ep_size)
|
||||
expert_map = None
|
||||
n_local_expert = n_expert
|
||||
if (ep_size != 1):
|
||||
n_local_expert, expert_map = determine_expert_map(
|
||||
ep_size, ep_rank, n_expert)
|
||||
expert_map = expert_map.cuda()
|
||||
start_expert = n_local_expert * ep_rank
|
||||
current_platform.seed_everything(0)
|
||||
hidden_states = torch.randn((n_token, n_hidden), device="cuda").to(dtype)
|
||||
gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype)
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
hidden_states, gating_output, topk, False)
|
||||
gold0, gold1, gold2, gold3, valid_row_idx = torch_permute(
|
||||
hidden_states,
|
||||
topk_ids,
|
||||
token_expert_indices,
|
||||
topk,
|
||||
n_expert,
|
||||
n_local_expert,
|
||||
start_expert,
|
||||
expert_map=expert_map,
|
||||
align_block_size=align_block_size,
|
||||
fill_invalid_expert=fill_invalid_expert)
|
||||
|
||||
result0, result1, result2, result3 = moe_permute(
|
||||
hidden_states, topk_weights, topk_ids, token_expert_indices, topk,
|
||||
n_expert, n_local_expert, expert_map, align_block_size,
|
||||
fill_invalid_expert)
|
||||
|
||||
# check expert_first_token_offset
|
||||
torch.testing.assert_close(gold1, result1, atol=0, rtol=0)
|
||||
# check src_row_id2dst_row_id_map
|
||||
torch.testing.assert_close(gold2, result2, atol=0, rtol=0)
|
||||
# check mindice
|
||||
torch.testing.assert_close(gold3, result3, atol=0, rtol=0)
|
||||
# check permuted_hidden_states, only valid token
|
||||
torch.testing.assert_close(gold0[valid_row_idx],
|
||||
result0[valid_row_idx],
|
||||
atol=0,
|
||||
rtol=0)
|
||||
|
||||
# add a random tensor to simulate group gemm
|
||||
result0 = 0.5 * result0 + torch.randn_like(result0)
|
||||
|
||||
result4 = moe_unpermute(result0, topk_weights, topk_ids, result2, result1,
|
||||
topk, n_expert, n_local_expert)
|
||||
gold4 = torch_unpermute(result0, topk_weights, topk_ids,
|
||||
token_expert_indices, result2, valid_row_idx, topk,
|
||||
n_local_expert)
|
||||
|
||||
# check unpermuted hidden
|
||||
torch.testing.assert_close(result4, gold4, atol=2e-2, rtol=0)
|
||||
@ -84,7 +84,8 @@ def test_fused_marlin_moe_awq(
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
|
||||
topk_weights, topk_ids = fused_topk(a, score, topk, False)
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
a, score, topk, False)
|
||||
marlin_output = torch.ops.vllm.fused_marlin_moe(
|
||||
a,
|
||||
qweight1,
|
||||
|
||||
@ -338,7 +338,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
|
||||
M, K = a.shape
|
||||
N = w2.shape[-1]
|
||||
|
||||
topk_weight, topk_ids = fused_topk(a, score.float(), topk, False)
|
||||
topk_weight, topk_ids, token_expert_indices = fused_topk(
|
||||
a, score.float(), topk, False)
|
||||
|
||||
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
|
||||
|
||||
@ -435,7 +436,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
|
||||
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score,
|
||||
topk, block_size)
|
||||
|
||||
topk_weights, topk_ids = fused_topk(a, score.float(), topk, False)
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
a, score.float(), topk, False)
|
||||
|
||||
out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)
|
||||
|
||||
|
||||
@ -71,8 +71,8 @@ def single_marlin_moe(
|
||||
E = w.shape[0]
|
||||
N = w.shape[2] // (num_bits // 2)
|
||||
|
||||
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
|
||||
renormalize)
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
hidden_states, gating_output, topk, renormalize)
|
||||
|
||||
# This might not be an optimal config for a single MMM
|
||||
get_config_func = functools.partial(try_get_optimal_moe_config,
|
||||
|
||||
@ -854,7 +854,7 @@ def fused_topk(
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], (
|
||||
"Number of tokens mismatch")
|
||||
|
||||
@ -868,20 +868,19 @@ def fused_topk(
|
||||
topk,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
token_expert_indicies = torch.empty(M,
|
||||
topk,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
token_expert_indices = torch.empty(M,
|
||||
topk,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
|
||||
gating_output_float = gating_output.float() # TODO(woosuk): Optimize this.
|
||||
|
||||
topk_func = dispatch_topk_func()
|
||||
topk_weights, topk_ids = topk_func(topk_weights, topk_ids,
|
||||
token_expert_indicies,
|
||||
token_expert_indices,
|
||||
gating_output_float, renormalize)
|
||||
|
||||
del token_expert_indicies # Not used. Will be used in the future.
|
||||
return topk_weights, topk_ids
|
||||
return topk_weights, topk_ids, token_expert_indices
|
||||
|
||||
|
||||
# This is used by the Deepseek-V2 and Deepseek-V3 model
|
||||
@ -1510,8 +1509,8 @@ def fused_moe(
|
||||
topk, renormalize,
|
||||
num_expert_group, topk_group)
|
||||
elif custom_routing_function is None:
|
||||
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
|
||||
renormalize)
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
hidden_states, gating_output, topk, renormalize)
|
||||
else:
|
||||
topk_weights, topk_ids = custom_routing_function(
|
||||
hidden_states, gating_output, topk, renormalize)
|
||||
|
||||
@ -801,10 +801,11 @@ class FusedMoE(torch.nn.Module):
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
elif custom_routing_function is None:
|
||||
topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize)
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize)
|
||||
else:
|
||||
topk_weights, topk_ids = custom_routing_function(
|
||||
hidden_states=hidden_states,
|
||||
|
||||
116
vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
Normal file
116
vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
Normal file
@ -0,0 +1,116 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def moe_permute(
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
token_expert_indices: torch.Tensor,
|
||||
topk: int,
|
||||
n_expert: int,
|
||||
n_local_expert: int,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
align_block_size: Optional[int] = None,
|
||||
fill_invalid_expert: int = -1
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
This function expands and permutes activation to gather uncontinuous tokens
|
||||
for each expert.
|
||||
Parameters:
|
||||
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
|
||||
- topk_weights (torch.Tensor): topk expert route weight for each token.
|
||||
- topk_ids (torch.Tensor): topk expert route id for each token.
|
||||
- token_expert_indices (torch.Tensor): indice for expanded hidden.
|
||||
- topk (int): The number of top-k experts to select.
|
||||
- n_expert (int): The number of expert.
|
||||
- n_local_expert (int): The number of expert in current EP rank.
|
||||
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
|
||||
from the global expert space to the local expert space of the expert
|
||||
parallel shard.
|
||||
- align_block_size (Optional[int]): align group gemm block size for deepgemm
|
||||
- fill_invalid_expert(int): fill expert id in m_indices for invalid expert
|
||||
to workaround DeepGemm unsupported -1 in m_indices
|
||||
Returns:
|
||||
- permuted_hidden_states (torch.Tensor): permuted activation.
|
||||
- expert_first_token_offset (torch.Tensor): offset of the first token
|
||||
of each expert for standard grouped gemm. if enable 'align_block_size'
|
||||
expert_first_token_offset will align up to 'align_block_size'.
|
||||
- src_row_id2dst_row_id_map (torch.Tensor): idx map for moe_unpermute.
|
||||
- m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records
|
||||
the group which the j-th row of the LHS belong to.`
|
||||
"""
|
||||
n_token, n_hidden = hidden_states.shape
|
||||
assert (n_hidden * hidden_states.element_size()
|
||||
) % 16 == 0, "permue kernel need hidden dim align to 16B"
|
||||
permuted_row_size = n_token * topk
|
||||
if align_block_size is not None:
|
||||
permuted_row_size = (permuted_row_size + n_expert *
|
||||
(align_block_size - 1) + align_block_size -
|
||||
1) // align_block_size * align_block_size
|
||||
|
||||
permuted_hidden_states = torch.empty(
|
||||
(permuted_row_size, n_hidden),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
m_indices = torch.full((permuted_row_size, ),
|
||||
fill_invalid_expert,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
expert_first_token_offset = torch.empty(n_local_expert + 1,
|
||||
dtype=torch.int64,
|
||||
device=hidden_states.device)
|
||||
src_row_id2dst_row_id_map = torch.empty((n_token, topk),
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
torch.ops._moe_C.moe_permute(hidden_states, topk_weights, topk_ids,
|
||||
token_expert_indices, expert_map, n_expert,
|
||||
n_local_expert, topk, align_block_size,
|
||||
permuted_hidden_states,
|
||||
expert_first_token_offset,
|
||||
src_row_id2dst_row_id_map, m_indices)
|
||||
return (permuted_hidden_states, expert_first_token_offset,
|
||||
src_row_id2dst_row_id_map, m_indices)
|
||||
|
||||
|
||||
def moe_unpermute(
|
||||
permuted_hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
src_row_id2dst_row_id_map: torch.Tensor,
|
||||
expert_first_token_offset: torch.Tensor,
|
||||
topk: int,
|
||||
n_expert: int,
|
||||
n_local_expert: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function expands and permutes activation to gathering uncontinuous
|
||||
tokens for each expert.
|
||||
Parameters:
|
||||
- permuted_hidden_states (torch.Tensor): permuted activation.
|
||||
- topk_weights (torch.Tensor): topk expert route weight for each token.
|
||||
- topk_ids (torch.Tensor): topk expert route id for each token.
|
||||
- expert_first_token_offset (torch.Tensor): offset of the first token
|
||||
of each expert for grouped gemm.
|
||||
- topk (int): The number of top-k experts to select.
|
||||
- n_expert (int): The number of expert.
|
||||
- n_local_expert (int): The number of expert in current EP rank.
|
||||
Returns:
|
||||
- hidden_states (torch.Tensor): The reduced and unpermuted activation
|
||||
tensor.
|
||||
"""
|
||||
n_token, n_hidden = topk_weights.shape[0], permuted_hidden_states.shape[-1]
|
||||
assert (n_hidden * permuted_hidden_states.element_size()
|
||||
) % 16 == 0, "unpermue kernel need hidden dim align to 16B"
|
||||
hidden_states = torch.empty((n_token, n_hidden),
|
||||
dtype=permuted_hidden_states.dtype,
|
||||
device=permuted_hidden_states.device)
|
||||
|
||||
torch.ops._moe_C.moe_unpermute(permuted_hidden_states, topk_weights,
|
||||
topk_ids, src_row_id2dst_row_id_map,
|
||||
expert_first_token_offset, n_expert,
|
||||
n_local_expert, topk, hidden_states)
|
||||
return hidden_states
|
||||
@ -175,10 +175,8 @@ class ArcticMoE(nn.Module):
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
do_normalize = self.top_k > 1
|
||||
topk_weights, topk_ids = fused_topk(hidden_states,
|
||||
router_logits,
|
||||
self.top_k,
|
||||
renormalize=do_normalize)
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
hidden_states, router_logits, self.top_k, renormalize=do_normalize)
|
||||
# topk_ids: (num_tokens, k)
|
||||
if self.is_quant:
|
||||
if 2 * num_tokens <= self.num_experts:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user