From 3e887d2e0c1fcd65efbfe02db8a824c761fe4d41 Mon Sep 17 00:00:00 2001 From: Caleb_Du <59528230+CalebDu@users.noreply.github.com> Date: Sat, 3 May 2025 02:31:55 +0800 Subject: [PATCH] permute/unpermute kernel for moe optimization (#14568) Signed-off-by: Caleb_Du --- CMakeLists.txt | 14 +- .../kernels/benchmark_grouped_gemm_cutlass.py | 3 +- benchmarks/kernels/benchmark_moe.py | 4 +- .../benchmark_moe_permute_unpermute.py | 349 ++++++++++++++++++ csrc/moe/moe_permute_unpermute_op.cu | 133 +++++++ csrc/moe/permute_unpermute_kernels/dispatch.h | 53 +++ .../moe_permute_unpermute_kernel.cu | 229 ++++++++++++ .../moe_permute_unpermute_kernel.h | 95 +++++ .../moe_permute_unpermute_kernel.inl | 211 +++++++++++ csrc/moe/torch_bindings.cpp | 22 ++ tests/kernels/moe/test_moe.py | 3 +- .../kernels/moe/test_moe_permute_unpermute.py | 223 +++++++++++ tests/kernels/quantization/test_awq_marlin.py | 3 +- tests/kernels/quantization/test_block_fp8.py | 6 +- .../layers/fused_moe/fused_marlin_moe.py | 4 +- .../layers/fused_moe/fused_moe.py | 19 +- vllm/model_executor/layers/fused_moe/layer.py | 9 +- .../layers/fused_moe/moe_permute_unpermute.py | 116 ++++++ vllm/model_executor/models/arctic.py | 6 +- 19 files changed, 1474 insertions(+), 28 deletions(-) create mode 100644 benchmarks/kernels/benchmark_moe_permute_unpermute.py create mode 100644 csrc/moe/moe_permute_unpermute_op.cu create mode 100644 csrc/moe/permute_unpermute_kernels/dispatch.h create mode 100644 csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu create mode 100644 csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h create mode 100644 csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl create mode 100644 tests/kernels/moe/test_moe_permute_unpermute.py create mode 100644 vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 72740279d0e05..be84c81295568 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,7 +15,6 @@ project(vllm_extensions LANGUAGES CXX) # CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py) set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM") - message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") message(STATUS "Target device: ${VLLM_TARGET_DEVICE}") @@ -682,6 +681,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() endif() +if(VLLM_GPU_LANG STREQUAL "CUDA") + set(MOE_PERMUTE_SRC + "csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu" + "csrc/moe/moe_permute_unpermute_op.cu") + + set_gencode_flags_for_srcs( + SRCS "${MARLIN_PERMUTE_SRC}" + CUDA_ARCHS "${MOE_PERMUTE_ARCHS}") + + list(APPEND VLLM_MOE_EXT_SRC "${MOE_PERMUTE_SRC}") +endif() message(STATUS "Enabling moe extension.") define_gpu_extension_target( _moe_C @@ -690,6 +700,8 @@ define_gpu_extension_target( SOURCES ${VLLM_MOE_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index bcdbf6c7551a3..c92ea43e8260d 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -90,7 +90,8 @@ def bench_run(results: list[benchmark.Measurement], model: str, score = torch.randn((m, num_experts), device="cuda", dtype=dtype) - topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + topk_weights, topk_ids, token_expert_indices = fused_topk( + a, score, topk, renormalize=False) def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index a274537a67515..c34f97dec8ea3 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -115,8 +115,8 @@ def benchmark_config(config: BenchmarkConfig, from vllm.model_executor.layers.fused_moe import override_config with override_config(config): if use_deep_gemm: - topk_weights, topk_ids = fused_topk(x, input_gating, topk, - False) + topk_weights, topk_ids, token_expert_indices = fused_topk( + x, input_gating, topk, False) return fused_experts( x, w1, diff --git a/benchmarks/kernels/benchmark_moe_permute_unpermute.py b/benchmarks/kernels/benchmark_moe_permute_unpermute.py new file mode 100644 index 0000000000000..937df96246514 --- /dev/null +++ b/benchmarks/kernels/benchmark_moe_permute_unpermute.py @@ -0,0 +1,349 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +from typing import Any, TypedDict + +import ray +import torch +from transformers import AutoConfig + +from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( + _moe_permute, _moe_unpermute_and_reduce) +from vllm.model_executor.layers.fused_moe.fused_moe import * +from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import * +from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize +from vllm.platforms import current_platform +from vllm.utils import FlexibleArgumentParser + +FP8_DTYPE = current_platform.fp8_dtype() + + +class BenchmarkConfig(TypedDict): + BLOCK_SIZE_M: int + BLOCK_SIZE_N: int + BLOCK_SIZE_K: int + GROUP_SIZE_M: int + num_warps: int + num_stages: int + + +def benchmark_permute(num_tokens: int, + num_experts: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + num_iters: int = 100, + use_customized_permute: bool = False) -> float: + # init_dtype = torch.float16 if use_fp8_w8a8 else dtype + hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) + # output_hidden_states = torch.empty_like(hidden_states) + if use_fp8_w8a8: + align_block_size = 128 # deepgemm needs 128 m aligned block + qhidden_states, scale = _fp8_quantize(hidden_states, None, None) + else: + align_block_size = None + qhidden_states = hidden_states + + gating_output = torch.randn(num_iters, + num_tokens, + num_experts, + dtype=torch.float32) + + input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32) + topk_weights, topk_ids, token_expert_indices = fused_topk( + qhidden_states, input_gating, topk, False) + + def prepare(i: int): + input_gating.copy_(gating_output[i]) + + def run(): + if use_customized_permute: + (permuted_hidden_states, first_token_off, inv_perm_idx, + m_indices) = moe_permute( + qhidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + token_expert_indices=token_expert_indices, + topk=topk, + n_expert=num_experts, + n_local_expert=num_experts, + expert_map=None, + align_block_size=align_block_size, + ) + else: + (permuted_hidden_states, a1q_scale, sorted_token_ids, expert_ids, + inv_perm) = _moe_permute(qhidden_states, None, topk_ids, + num_experts, None, align_block_size) + + # JIT compilation & warmup + run() + torch.cuda.synchronize() + + # Capture 10 invocations with CUDA graph + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + for _ in range(10): + run() + torch.cuda.synchronize() + + # Warmup + for _ in range(5): + graph.replay() + torch.cuda.synchronize() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + latencies: list[float] = [] + for i in range(num_iters): + prepare(i) + torch.cuda.synchronize() + + start_event.record() + graph.replay() + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + avg = sum(latencies) / (num_iters * 10) * 1000 # us + graph.reset() + return avg + + +def benchmark_unpermute(num_tokens: int, + num_experts: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + num_iters: int = 100, + use_customized_permute: bool = False) -> float: + # init_dtype = torch.float16 if use_fp8_w8a8 else dtype + hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) + output_hidden_states = torch.empty_like(hidden_states) + if use_fp8_w8a8: + align_block_size = 128 # deepgemm needs 128 m aligned block + qhidden_states, scale = _fp8_quantize(hidden_states, None, None) + else: + align_block_size = None + qhidden_states = hidden_states + + input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32) + + topk_weights, topk_ids, token_expert_indices = fused_topk( + qhidden_states, input_gating, topk, False) + + def prepare(): + if use_customized_permute: + (permuted_hidden_states, first_token_off, inv_perm_idx, + m_indices) = moe_permute( + qhidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + token_expert_indices=token_expert_indices, + topk=topk, + n_expert=num_experts, + n_local_expert=num_experts, + expert_map=None, + align_block_size=align_block_size, + ) + # convert to fp16/bf16 as gemm output + return (permuted_hidden_states.to(dtype), first_token_off, + inv_perm_idx, m_indices) + else: + (permuted_qhidden_states, a1q_scale, sorted_token_ids, expert_ids, + inv_perm) = _moe_permute(qhidden_states, None, topk_ids, + num_experts, None, align_block_size) + # convert to fp16/bf16 as gemm output + return (permuted_qhidden_states.to(dtype), a1q_scale, + sorted_token_ids, expert_ids, inv_perm) + + def run(input: tuple): + if use_customized_permute: + (permuted_hidden_states, first_token_off, inv_perm_idx, + m_indices) = input + moe_unpermute(permuted_hidden_states, topk_weights, topk_ids, + inv_perm_idx, first_token_off, topk, num_experts, + num_experts) + else: + (permuted_hidden_states, a1q_scale, sorted_token_ids, expert_ids, + inv_perm) = input + _moe_unpermute_and_reduce(output_hidden_states, + permuted_hidden_states, inv_perm, + topk_weights) + + # JIT compilation & warmup + input = prepare() + run(input) + torch.cuda.synchronize() + + # Capture 10 invocations with CUDA graph + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + for _ in range(10): + run(input) + torch.cuda.synchronize() + + # Warmup + for _ in range(5): + graph.replay() + torch.cuda.synchronize() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + latencies: list[float] = [] + for i in range(num_iters): + torch.cuda.synchronize() + start_event.record() + graph.replay() + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + avg = sum(latencies) / (num_iters * 10) * 1000 # us + graph.reset() + return avg + + +@ray.remote(num_gpus=1) +class BenchmarkWorker: + + def __init__(self, seed: int) -> None: + torch.set_default_device("cuda") + current_platform.seed_everything(seed) + self.seed = seed + # Get the device ID to allocate tensors and kernels + # on the respective GPU. This is required for Ray to work + # correctly with multi-GPU tuning on the ROCm platform. + self.device_id = int(ray.get_gpu_ids()[0]) + + def benchmark( + self, + num_tokens: int, + num_experts: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_customized_permute: bool = False, + ) -> tuple[dict[str, int], float]: + current_platform.seed_everything(self.seed) + + permute_time = benchmark_permute( + num_tokens, + num_experts, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + num_iters=100, + use_customized_permute=use_customized_permute) + unpermute_time = benchmark_unpermute( + num_tokens, + num_experts, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + num_iters=100, + use_customized_permute=use_customized_permute) + return permute_time, unpermute_time + + +def get_weight_block_size_safety(config, default_value=None): + + quantization_config = getattr(config, 'quantization_config', {}) + if isinstance(quantization_config, dict): + return quantization_config.get('weight_block_size', default_value) + return default_value + + +def main(args: argparse.Namespace): + print(args) + + config = AutoConfig.from_pretrained( + args.model, trust_remote_code=args.trust_remote_code) + if config.architectures[0] == "DbrxForCausalLM": + E = config.ffn_config.moe_num_experts + topk = config.ffn_config.moe_top_k + elif config.architectures[0] == "JambaForCausalLM": + E = config.num_experts + topk = config.num_experts_per_tok + elif (config.architectures[0] == "DeepseekV3ForCausalLM" + or config.architectures[0] == "DeepseekV2ForCausalLM"): + E = config.n_routed_experts + topk = config.num_experts_per_tok + elif config.architectures[0] in [ + "Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM" + ]: + E = config.num_experts + topk = config.num_experts_per_tok + + else: + # Support for llama4 + config = config.get_text_config() + # Default: Mixtral. + E = config.num_local_experts + topk = config.num_experts_per_tok + + hidden_size = config.hidden_size + dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype + use_fp8_w8a8 = args.dtype == "fp8_w8a8" + use_int8_w8a16 = args.dtype == "int8_w8a16" + use_customized_permute = args.use_customized_permute + + if args.batch_size is None: + batch_sizes = [ + 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, + 2048, 3072, 4096 + ] + else: + batch_sizes = [args.batch_size] + + ray.init() + num_gpus = int(ray.available_resources()["GPU"]) + workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)] + + def _distribute(method: str, inputs: list[Any]) -> list[Any]: + outputs = [] + worker_idx = 0 + for input_args in inputs: + worker = workers[worker_idx] + worker_method = getattr(worker, method) + output = worker_method.remote(*input_args) + outputs.append(output) + worker_idx = (worker_idx + 1) % num_gpus + return ray.get(outputs) + + outputs = _distribute( + "benchmark", [(batch_size, E, hidden_size, topk, dtype, use_fp8_w8a8, + use_int8_w8a16, use_customized_permute) + for batch_size in batch_sizes]) + + for batch_size, (permute, unpermute) in zip(batch_sizes, outputs): + print(f"Batch size: {batch_size}") + print(f"Permute time: {permute:.2f} us") + print(f"Unpermute time: {unpermute:.2f} us") + + +if __name__ == "__main__": + parser = FlexibleArgumentParser() + parser.add_argument("--model", + type=str, + default="mistralai/Mixtral-8x7B-Instruct-v0.1") + parser.add_argument("--dtype", + type=str, + choices=["auto", "fp8_w8a8", "int8_w8a16"], + default="auto") + parser.add_argument("--use-customized-permute", action="store_true") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--batch-size", type=int, required=False) + parser.add_argument("--trust-remote-code", action="store_true") + args = parser.parse_args() + + main(args) diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu new file mode 100644 index 0000000000000..76d5f0eab0218 --- /dev/null +++ b/csrc/moe/moe_permute_unpermute_op.cu @@ -0,0 +1,133 @@ +#include +#include +#include +#include "permute_unpermute_kernels/moe_permute_unpermute_kernel.h" +#include "permute_unpermute_kernels/dispatch.h" +#include "core/registration.h" + +void moe_permute( + const torch::Tensor& input, // [n_token, hidden] + const torch::Tensor& topk_weights, //[n_token, topk] + torch::Tensor& topk_ids, // [n_token, topk] + const torch::Tensor& token_expert_indicies, // [n_token, topk] + const std::optional& expert_map, // [n_expert] + int64_t n_expert, int64_t n_local_expert, int64_t topk, + const std::optional& align_block_size, + torch::Tensor& + permuted_input, // [topk * n_token/align_block_size_m, hidden] + torch::Tensor& expert_first_token_offset, // [n_local_expert + 1] + torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk] + torch::Tensor& m_indices) { // [align_expand_m] + TORCH_CHECK(topk_weights.scalar_type() == at::ScalarType::Float, + "topk_weights must be float32"); + TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long, + "expert_first_token_offset must be int64"); + TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int, + "topk_ids must be int32"); + TORCH_CHECK(token_expert_indicies.scalar_type() == at::ScalarType::Int, + "token_expert_indicies must be int32"); + TORCH_CHECK(src_row_id2dst_row_id_map.scalar_type() == at::ScalarType::Int, + "src_row_id2dst_row_id_map must be int32"); + TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1, + "expert_first_token_offset shape != n_local_expert+1") + TORCH_CHECK( + src_row_id2dst_row_id_map.sizes() == token_expert_indicies.sizes(), + "token_expert_indicies shape must be same as src_row_id2dst_row_id_map"); + auto n_token = input.sizes()[0]; + auto n_hidden = input.sizes()[1]; + auto align_block_size_value = + align_block_size.has_value() ? align_block_size.value() : -1; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + const long sorter_size = + CubKeyValueSorter::getWorkspaceSize(n_token * topk, n_expert); + auto sort_workspace = torch::empty( + {sorter_size}, + torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false)); + auto permuted_experts_id = torch::empty_like(topk_ids); + auto dst_row_id2src_row_id_map = torch::empty_like(src_row_id2dst_row_id_map); + auto align_expert_first_token_offset = + torch::zeros_like(expert_first_token_offset); + + CubKeyValueSorter sorter{}; + int64_t* valid_num_ptr = nullptr; + // pre-process kernel for expert-parallelism: + // no local expert id plus "n_expert" offset for priority to local expert + // map local expert id [n, .., n+n_local_expert-1] to [0, n_local_expert -1] + // For example, 4 expert with ep_size=2. ep_rank=1 owns global expert id + // [2,3] with expert_map[-1, -1, 0, 1], preprocess_topk_id process topk_ids + // and map global expert id [2, 3] to local_expert id [0, 1] and map global + // expert id [0, 1] ( not in ep rank=1) to [4, 5] by plus n_expert. This map + // operation is to make local expert high priority in following sort topk_ids + // and scan local expert_first_token_offset for each ep rank for next group + // gemm. + if (expert_map.has_value()) { + const int* expert_map_ptr = get_ptr(expert_map.value()); + valid_num_ptr = + get_ptr(expert_first_token_offset) + n_local_expert; + preprocessTopkIdLauncher(get_ptr(topk_ids), n_token * topk, + expert_map_ptr, n_expert, stream); + } + // expert sort topk expert id and scan expert id get expert_first_token_offset + sortAndScanExpert(get_ptr(topk_ids), get_ptr(token_expert_indicies), + get_ptr(permuted_experts_id), + get_ptr(dst_row_id2src_row_id_map), + get_ptr(expert_first_token_offset), n_token, + n_expert, n_local_expert, topk, sorter, + get_ptr(sort_workspace), stream); + + // dispatch expandInputRowsKernelLauncher + MOE_DISPATCH(input.scalar_type(), [&] { + expandInputRowsKernelLauncher( + get_ptr(input), get_ptr(permuted_input), + get_ptr(topk_weights), get_ptr(permuted_experts_id), + get_ptr(dst_row_id2src_row_id_map), + get_ptr(src_row_id2dst_row_id_map), + get_ptr(expert_first_token_offset), n_token, valid_num_ptr, + n_hidden, topk, n_local_expert, align_block_size_value, stream); + }); + + // get m_indices and update expert_first_token_offset with align block + getMIndices(get_ptr(expert_first_token_offset), + get_ptr(align_expert_first_token_offset), + get_ptr(m_indices), n_local_expert, align_block_size_value, + stream); + if (align_block_size.has_value()) { + // update align_expert_first_token_offset + expert_first_token_offset.copy_(align_expert_first_token_offset); + } +} + +void moe_unpermute( + const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden] + const torch::Tensor& topk_weights, //[n_token, topk] + const torch::Tensor& topk_ids, // [n_token, topk] + const torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk] + const torch::Tensor& expert_first_token_offset, // [n_local_expert+1] + int64_t n_expert, int64_t n_local_expert, int64_t topk, + torch::Tensor& hidden_states // [n_token, hidden] +) { + TORCH_CHECK(src_row_id2dst_row_id_map.sizes() == topk_ids.sizes(), + "topk_ids shape must be same as src_row_id2dst_row_id_map"); + TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int, + "topk_ids must be int32"); + TORCH_CHECK( + permuted_hidden_states.scalar_type() == hidden_states.scalar_type(), + "topk_ids dtype must be same as src_row_id2dst_row_id_map"); + auto n_token = hidden_states.size(0); + auto n_hidden = hidden_states.size(1); + auto stream = at::cuda::getCurrentCUDAStream().stream(); + const int64_t* valid_ptr = + get_ptr(expert_first_token_offset) + n_local_expert; + MOE_DISPATCH(hidden_states.scalar_type(), [&] { + finalizeMoeRoutingKernelLauncher( + get_ptr(permuted_hidden_states), + get_ptr(hidden_states), get_ptr(topk_weights), + get_ptr(src_row_id2dst_row_id_map), get_ptr(topk_ids), + n_token, n_hidden, topk, valid_ptr, stream); + }); +} + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("moe_permute", &moe_permute); + m.impl("moe_unpermute", &moe_unpermute); +} \ No newline at end of file diff --git a/csrc/moe/permute_unpermute_kernels/dispatch.h b/csrc/moe/permute_unpermute_kernels/dispatch.h new file mode 100644 index 0000000000000..41932cdd85bcd --- /dev/null +++ b/csrc/moe/permute_unpermute_kernels/dispatch.h @@ -0,0 +1,53 @@ +#pragma once +#include +#define MOE_SWITCH(TYPE, ...) \ + at::ScalarType _st = ::detail::scalar_type(TYPE); \ + switch (_st) { \ + __VA_ARGS__ \ + default: \ + TORCH_CHECK(false, "[moe permute]data type dispatch fail!") \ + } + +#define MOE_DISPATCH_CASE(enum_type, ...) \ + case enum_type: { \ + using scalar_t = ScalarType2CudaType::type; \ + __VA_ARGS__(); \ + break; \ + } +#define MOE_DISPATCH_FLOAT_CASE(...) \ + MOE_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + MOE_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + MOE_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + MOE_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) \ + MOE_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) + +#define MOE_DISPATCH(TYPE, ...) \ + MOE_SWITCH(TYPE, MOE_DISPATCH_FLOAT_CASE(__VA_ARGS__)) + +template +struct ScalarType2CudaType; + +template <> +struct ScalarType2CudaType { + using type = float; +}; +template <> +struct ScalarType2CudaType { + using type = half; +}; +template <> +struct ScalarType2CudaType { + using type = __nv_bfloat16; +}; + +// #if __CUDA_ARCH__ >= 890 +// fp8 +template <> +struct ScalarType2CudaType { + using type = __nv_fp8_e5m2; +}; +template <> +struct ScalarType2CudaType { + using type = __nv_fp8_e4m3; +}; +// #endif \ No newline at end of file diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu new file mode 100644 index 0000000000000..aa353d0f0437f --- /dev/null +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu @@ -0,0 +1,229 @@ + +#include "moe_permute_unpermute_kernel.h" + +// CubKeyValueSorter definition begin +CubKeyValueSorter::CubKeyValueSorter() + : num_experts_(0), num_bits_(sizeof(int) * 8) {} + +int CubKeyValueSorter::expertsToBits(int num_experts) { + // Max value we represent is V = num_experts + (num_experts - 1) = 2 * + // num_experts - 1 The maximum number of bits is therefore floor(log2(V)) + 1 + return static_cast(log2(2 * num_experts - 1)) + 1; +} + +CubKeyValueSorter::CubKeyValueSorter(int const num_experts) + : num_experts_(num_experts), num_bits_(expertsToBits(num_experts)) {} + +void CubKeyValueSorter::updateNumExperts(int const num_experts) { + num_experts_ = num_experts; + num_bits_ = expertsToBits(num_experts); +} + +size_t CubKeyValueSorter::getWorkspaceSize(size_t const num_key_value_pairs, + int const num_experts) { + int num_bits = expertsToBits(num_experts); + size_t required_storage = 0; + int* null_int = nullptr; + cub::DeviceRadixSort::SortPairs(nullptr, required_storage, null_int, null_int, + null_int, null_int, num_key_value_pairs, 0, + num_bits); + + // when num_key_value_pairs, num_experts, num_bits, required_storage = 64, + // 4, 3, 0 The required_storage seems to vary between 0 and 1 for the same + // inputs + if (required_storage == 0) { + required_storage = 1; + } + return required_storage; +} + +void CubKeyValueSorter::run(void* workspace, size_t const workspace_size, + int const* keys_in, int* keys_out, + int const* values_in, int* values_out, + size_t const num_key_value_pairs, + cudaStream_t stream) { + size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs, num_experts_); + size_t actual_ws_size = workspace_size; + + TORCH_CHECK(expected_ws_size <= workspace_size, + "[CubKeyValueSorter::run] The allocated workspace is too small " + "to run this problem."); + cub::DeviceRadixSort::SortPairs(workspace, actual_ws_size, keys_in, keys_out, + values_in, values_out, num_key_value_pairs, 0, + num_bits_, stream); +} +// CubKeyValueSorter definition end + +static inline size_t pad_to_multiple_of_16(size_t const& input) { + static constexpr int ALIGNMENT = 16; + return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT); +} +template +__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, + int64_t const arr_length, + T const target) { + int64_t low = 0, high = arr_length - 1, target_location = -1; + while (low <= high) { + int64_t mid = (low + high) / 2; + + if (sorted_indices[mid] >= target) { + high = mid - 1; + } else { + low = mid + 1; + target_location = mid; + } + } + return target_location + 1; +} + +// Calculates the start offset of the tokens for a given expert. The last +// element is the total number of valid tokens +__global__ void computeExpertFirstTokenOffsetKernel( + int const* sorted_experts, int64_t const sorted_experts_len, + int const num_experts, int64_t* expert_first_token_offset) { + // First, compute the global tid. We only need 1 thread per expert. + int const expert = blockIdx.x * blockDim.x + threadIdx.x; + + // Note that expert goes [0, num_experts] (inclusive) because we want a count + // for the total number of active tokens at the end of the scan. + if (expert >= num_experts + 1) { + return; + } + expert_first_token_offset[expert] = + findTotalEltsLessThanTarget(sorted_experts, sorted_experts_len, expert); +} + +void computeExpertFirstTokenOffset(int const* sorted_indices, + int const total_indices, + int const num_experts, + int64_t* expert_first_token_offset, + cudaStream_t stream) { + int const num_entries = num_experts + 1; + int const threads = std::min(1024, num_entries); + int const blocks = (num_entries + threads - 1) / threads; + + computeExpertFirstTokenOffsetKernel<<>>( + sorted_indices, total_indices, num_experts, expert_first_token_offset); +} + +void sortAndScanExpert(int* expert_for_source_row, const int* source_rows, + int* permuted_experts, int* permuted_rows, + int64_t* expert_first_token_offset, int num_rows, + int num_experts, int num_experts_per_node, int k, + CubKeyValueSorter& sorter, void* sorter_ws, + cudaStream_t stream) { + int64_t const expanded_num_rows = static_cast(k) * num_rows; + // We need to use the full num_experts because that is the sentinel value used + // by topk for disabled experts + sorter.updateNumExperts(num_experts); + size_t const sorter_ws_size_bytes = pad_to_multiple_of_16( + sorter.getWorkspaceSize(expanded_num_rows, num_experts)); + sorter.run((void*)sorter_ws, sorter_ws_size_bytes, expert_for_source_row, + permuted_experts, source_rows, permuted_rows, expanded_num_rows, + stream); + computeExpertFirstTokenOffset(permuted_experts, expanded_num_rows, + num_experts_per_node, expert_first_token_offset, + stream); +} + +__global__ void preprocessTopkIdKernel(int* topk_id_ptr, int size, + const int* expert_map_ptr, + int num_experts) { + auto tidx = threadIdx.x; + auto bidx = blockIdx.x; + auto lidx = tidx & 31; + auto widx = tidx >> 5; + auto warp_count = (blockDim.x + 31) >> 5; + auto offset = bidx * blockDim.x; + auto bound = min(offset + blockDim.x, size); + extern __shared__ int smem_expert_map[]; + // store expert_map in smem + for (int i = tidx; i < num_experts; i += blockDim.x) { + smem_expert_map[i] = expert_map_ptr[i]; + } + __syncthreads(); + + // query global expert id in expert map. + // if global expert id = -1 in exert map, plus n_expert + // else set global expert id = exert map[global expert id] + if (offset + tidx < bound) { + auto topk_id = topk_id_ptr[offset + tidx]; + auto local_expert_idx = smem_expert_map[topk_id]; + if (local_expert_idx == -1) { + topk_id += num_experts; + } else { + topk_id = local_expert_idx; + } + __syncwarp(); + topk_id_ptr[offset + tidx] = topk_id; + } +} +void preprocessTopkIdLauncher(int* topk_id_ptr, int size, + const int* expert_map_ptr, int num_experts, + cudaStream_t stream) { + int block = std::min(size, 1024); + int grid = (size + block - 1) / block; + int smem_size = (num_experts) * sizeof(int); + preprocessTopkIdKernel<<>>( + topk_id_ptr, size, expert_map_ptr, num_experts); +} + +template +__global__ void getMIndicesKernel(int64_t* expert_first_token_offset, + int64_t* align_expert_first_token_offset, + int* m_indices, const int num_local_expert, + const int align_block_size) { + int eidx = blockIdx.x; + int tidx = threadIdx.x; + extern __shared__ int64_t smem_expert_first_token_offset[]; + for (int i = tidx; i <= num_local_expert; i += blockDim.x) { + smem_expert_first_token_offset[tidx] = __ldg(expert_first_token_offset + i); + } + __syncthreads(); + auto last_token_offset = smem_expert_first_token_offset[eidx + 1]; + auto first_token_offset = smem_expert_first_token_offset[eidx]; + int n_token_in_expert = last_token_offset - first_token_offset; + + if constexpr (ALIGN_BLOCK_SIZE) { + n_token_in_expert = (n_token_in_expert + align_block_size - 1) / + align_block_size * align_block_size; + // round up to ALIGN_BLOCK_SIZE + int64_t accumulate_align_offset = 0; + for (int i = 1; i <= eidx + 1; i++) { + int n_token = smem_expert_first_token_offset[i] - + smem_expert_first_token_offset[i - 1]; + accumulate_align_offset = + accumulate_align_offset + (n_token + align_block_size - 1) / + align_block_size * align_block_size; + if (i == eidx) { + first_token_offset = accumulate_align_offset; + } + // last block store align_expert_first_token_offset + if (eidx == num_local_expert - 1 && threadIdx.x == 0) { + align_expert_first_token_offset[i] = accumulate_align_offset; + } + } + } + for (int idx = tidx; idx < n_token_in_expert; idx += blockDim.x) { + // update m_indice with expert id + m_indices[first_token_offset + idx] = eidx; + } +} + +void getMIndices(int64_t* expert_first_token_offset, + int64_t* align_expert_first_token_offset, int* m_indices, + int num_local_expert, const int align_block_size, + cudaStream_t stream) { + int block = 256; + int grid = num_local_expert; + int smem_size = sizeof(int64_t) * (num_local_expert + 1); + if (align_block_size == -1) { + getMIndicesKernel<<>>( + expert_first_token_offset, align_expert_first_token_offset, m_indices, + num_local_expert, align_block_size); + } else { + getMIndicesKernel<<>>( + expert_first_token_offset, align_expert_first_token_offset, m_indices, + num_local_expert, align_block_size); + } +} \ No newline at end of file diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h new file mode 100644 index 0000000000000..43c29721cd16e --- /dev/null +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h @@ -0,0 +1,95 @@ +#pragma once +// reference from tensorrt_llm moe kernel implementation archive in +// https://github.com/BBuf/tensorrt-llm-moe/tree/master + +#include +#include +#include "dispatch.h" +#include +#include +#include +#include "cutlass/numeric_size.h" +#include "cutlass/array.h" + +template +inline T* get_ptr(torch::Tensor& t) { + return reinterpret_cast(t.data_ptr()); +} + +template +inline const T* get_ptr(const torch::Tensor& t) { + return reinterpret_cast(t.data_ptr()); +} + +class CubKeyValueSorter { + public: + CubKeyValueSorter(); + + CubKeyValueSorter(int const num_experts); + + void updateNumExperts(int const num_experts); + + static size_t getWorkspaceSize(size_t const num_key_value_pairs, + int const num_experts); + + void run(void* workspace, size_t const workspace_size, int const* keys_in, + int* keys_out, int const* values_in, int* values_out, + size_t const num_key_value_pairs, cudaStream_t stream); + + private: + static int expertsToBits(int experts); + int num_experts_; + int num_bits_; +}; + +void computeExpertFirstTokenOffset(int const* sorted_indices, + int const total_indices, + int const num_experts, + int64_t* expert_first_token_offset, + cudaStream_t stream); + +void sortAndScanExpert(int* expert_for_source_row, const int* source_rows, + int* permuted_experts, int* permuted_rows, + int64_t* expert_first_token_offset, int num_rows, + int num_experts, int num_experts_per_node, int k, + CubKeyValueSorter& sorter, void* sorter_ws, + cudaStream_t stream); + +template +void expandInputRowsKernelLauncher( + T const* unpermuted_input, T* permuted_output, + const float* unpermuted_scales, int* sorted_experts, + int const* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + int64_t* expert_first_token_offset, int64_t const num_rows, + int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k, + int num_local_experts, const int& align_block_size, cudaStream_t stream); + +// Final kernel to unpermute and scale +// This kernel unpermutes the original data, does the k-way reduction and +// performs the final skip connection. +template +__global__ void finalizeMoeRoutingKernel( + T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, + float const* scales, int const* expanded_source_row_to_expanded_dest_row, + int const* expert_for_source_row, int64_t const orig_cols, int64_t const k, + int64_t const* num_valid_ptr); + +template +void finalizeMoeRoutingKernelLauncher( + T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, + float const* scales, int const* expanded_source_row_to_expanded_dest_row, + int const* expert_for_source_row, int64_t const num_rows, + int64_t const cols, int64_t const k, int64_t const* num_valid_ptr, + cudaStream_t stream); + +void preprocessTopkIdLauncher(int* topk_id_ptr, int size, + const int* expert_map_ptr, int num_experts, + cudaStream_t stream); + +void getMIndices(int64_t* expert_first_token_offset, + int64_t* align_expert_first_token_offset, int* m_indices, + int num_local_expert, const int align_block_size, + cudaStream_t stream); + +#include "moe_permute_unpermute_kernel.inl" diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl new file mode 100644 index 0000000000000..42441800fb110 --- /dev/null +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl @@ -0,0 +1,211 @@ +#pragma once + +template +__global__ void expandInputRowsKernel( + T const* unpermuted_input, T* permuted_output, + const float* unpermuted_scales, int* sorted_experts, + int const* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + int64_t* expert_first_token_offset, int64_t const num_rows, + int64_t const* num_dest_rows, int64_t const cols, int64_t k, + int num_local_experts, int align_block_size) { + // Reverse permutation map. + // I do this so that later, we can use the source -> dest map to do the k-way + // reduction and unpermuting. I need the reverse map for that reduction to + // allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1 + // thread block will be responsible for all k summations. + int64_t expanded_dest_row = blockIdx.x; + int64_t const expanded_source_row = + expanded_dest_row_to_expanded_source_row[expanded_dest_row]; + int expert_id = sorted_experts[expanded_dest_row]; + + extern __shared__ int64_t smem_expert_first_token_offset[]; + int64_t align_expanded_row_accumulate = 0; + if constexpr (ALIGN_BLOCK_SIZE) { + // load g2s + for (int idx = threadIdx.x; idx < num_local_experts + 1; + idx += blockDim.x) { + smem_expert_first_token_offset[idx] = + __ldg(expert_first_token_offset + idx); + } + __syncthreads(); + int lane_idx = threadIdx.x & 31; + + if (lane_idx == 0) { + // set token_offset_in_expert = 0 if this expert is not local expert + int token_offset_in_expert = + expert_id >= num_local_experts + ? 0 + : expanded_dest_row - smem_expert_first_token_offset[expert_id]; + int64_t accumulate_align_offset = 0; +#pragma unroll 1 + for (int eidx = 1; eidx <= min(expert_id, num_local_experts); eidx++) { + auto n_token_in_expert = smem_expert_first_token_offset[eidx] - + smem_expert_first_token_offset[eidx - 1]; + accumulate_align_offset += (n_token_in_expert + align_block_size - 1) / + align_block_size * align_block_size; + } + expanded_dest_row = accumulate_align_offset + token_offset_in_expert; + } + // lane0 shuffle broadcast align_expanded_dest_row + expanded_dest_row = __shfl_sync(0xffffffff, expanded_dest_row, 0); + } + + if (threadIdx.x == 0) { + assert(expanded_dest_row <= INT32_MAX); + expanded_source_row_to_expanded_dest_row[expanded_source_row] = + static_cast(expanded_dest_row); + } + + if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) { + // Load 128-bits per thread + constexpr int64_t ELEM_PER_THREAD = 128 / cutlass::sizeof_bits::value; + using DataElem = cutlass::Array; + + // Duplicate and permute rows + int64_t const source_k_rank = expanded_source_row / num_rows; + int64_t const source_row = expanded_source_row % num_rows; + + auto const* source_row_ptr = + reinterpret_cast(unpermuted_input + source_row * cols); + auto* dest_row_ptr = + reinterpret_cast(permuted_output + expanded_dest_row * cols); + + int64_t const start_offset = threadIdx.x; + int64_t const stride = blockDim.x; + int64_t const num_elems_in_col = cols / ELEM_PER_THREAD; + + for (int elem_index = start_offset; elem_index < num_elems_in_col; + elem_index += stride) { + dest_row_ptr[elem_index] = source_row_ptr[elem_index]; + } + } +} + +template +void expandInputRowsKernelLauncher( + T const* unpermuted_input, T* permuted_output, + const float* unpermuted_scales, int* sorted_experts, + int const* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + int64_t* expert_first_token_offset, int64_t const num_rows, + int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k, + int num_local_experts, const int& align_block_size, cudaStream_t stream) { + int64_t const blocks = num_rows * k; + int64_t const threads = 256; + using FuncPtr = decltype(&expandInputRowsKernel); + FuncPtr func_map[2][2] = { + {&expandInputRowsKernel, + &expandInputRowsKernel}, + {&expandInputRowsKernel, + &expandInputRowsKernel}, + }; + bool is_check_skip = num_valid_tokens_ptr != nullptr; + bool is_align_block_size = align_block_size != -1; + auto func = func_map[is_check_skip][is_align_block_size]; + + int64_t smem_size = sizeof(int64_t) * (num_local_experts + 1); + + func<<>>( + unpermuted_input, permuted_output, unpermuted_scales, sorted_experts, + expanded_dest_row_to_expanded_source_row, + expanded_source_row_to_expanded_dest_row, expert_first_token_offset, + num_rows, num_valid_tokens_ptr, cols, k, num_local_experts, + align_block_size); +} + +template +__host__ __device__ constexpr static U arrayConvert(T const& input) { + using Type = typename U::Element; + static_assert(T::kElements == U::kElements); + U u; +#pragma unroll + for (int i = 0; i < U::kElements; i++) { + u[i] = static_cast(input[i]); + } + return u; +} + +template +__global__ void finalizeMoeRoutingKernel( + T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, + float const* scales, int const* expanded_source_row_to_expanded_dest_row, + int const* expert_for_source_row, int64_t const orig_cols, int64_t const k, + int64_t const* num_valid_ptr) { + assert(orig_cols % 4 == 0); + int64_t const original_row = blockIdx.x; + int64_t const num_rows = gridDim.x; + auto const offset = original_row * orig_cols; + OutputType* reduced_row_ptr = reduced_unpermuted_output + offset; + int64_t const num_valid = *num_valid_ptr; + + // Load 128-bits per thread, according to the smallest data type we read/write + constexpr int64_t FINALIZE_ELEM_PER_THREAD = + 128 / std::min(cutlass::sizeof_bits::value, + cutlass::sizeof_bits::value); + + int64_t const start_offset = threadIdx.x; + int64_t const stride = blockDim.x; + int64_t const num_elems_in_col = orig_cols / FINALIZE_ELEM_PER_THREAD; + + using InputElem = cutlass::Array; + using OutputElem = cutlass::Array; + using ComputeElem = cutlass::Array; + auto const* expanded_permuted_rows_v = + reinterpret_cast(expanded_permuted_rows); + auto* reduced_row_ptr_v = reinterpret_cast(reduced_row_ptr); + +#pragma unroll + for (int elem_index = start_offset; elem_index < num_elems_in_col; + elem_index += stride) { + ComputeElem thread_output; + thread_output.fill(0); + float row_rescale{0.f}; + for (int k_idx = 0; k_idx < k; ++k_idx) { + int64_t const expanded_original_row = original_row + k_idx * num_rows; + int64_t const expanded_permuted_row = + expanded_source_row_to_expanded_dest_row[expanded_original_row]; + + int64_t const k_offset = original_row * k + k_idx; + float const row_scale = scales[k_offset]; + + // Check after row_rescale has accumulated + if (CHECK_SKIPPED && expanded_permuted_row >= num_valid) { + continue; + } + + auto const* expanded_permuted_rows_row_ptr = + expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col; + + int64_t const expert_idx = expert_for_source_row[k_offset]; + + ComputeElem expert_result = arrayConvert( + expanded_permuted_rows_row_ptr[elem_index]); + thread_output = thread_output + row_scale * (expert_result); + } + + OutputElem output_elem = + arrayConvert(thread_output); + reduced_row_ptr_v[elem_index] = output_elem; + } +} + +template +void finalizeMoeRoutingKernelLauncher( + T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, + float const* scales, int const* expanded_source_row_to_expanded_dest_row, + int const* expert_for_source_row, int64_t const num_rows, + int64_t const cols, int64_t const k, int64_t const* num_valid_ptr, + cudaStream_t stream) { + int64_t const blocks = num_rows; + int64_t const threads = 256; + bool const check_finished = num_valid_ptr != nullptr; + using FuncPtr = decltype(&finalizeMoeRoutingKernel); + FuncPtr func_map[2] = {&finalizeMoeRoutingKernel, + &finalizeMoeRoutingKernel}; + auto* const kernel = func_map[check_finished]; + kernel<<>>( + expanded_permuted_rows, reduced_unpermuted_output, scales, + expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k, + num_valid_ptr); +} diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index d0de42251f97a..2a8b9bb39caa9 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -53,7 +53,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "int size_m, int size_n, int size_k," "bool is_full_k, bool use_atomic_add," "bool use_fp32_reduce, bool is_zp_float) -> Tensor"); + m.def( + "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " + "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " + "b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, " + "int b_q_type, SymInt size_m, " + "SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int " + "topk, " + "int moe_block_size, bool replicate_input, bool apply_weights)" + " -> Tensor"); + m.def( + "moe_permute(Tensor input, Tensor topk_weight, Tensor! topk_ids," + "Tensor token_expert_indicies, Tensor? expert_map, int n_expert," + "int n_local_expert," + "int topk, int? align_block_size,Tensor! permuted_input, Tensor! " + "expert_first_token_offset, Tensor! src_row_id2dst_row_id_map, Tensor! " + "m_indices)->()"); + + m.def( + "moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights," + "Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor " + "expert_first_token_offset, int n_expert, int n_local_expert,int " + "topk, Tensor! hidden_states)->()"); // conditionally compiled so impl registration is in source file #endif diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 425f36984a33b..f2cca65ae4209 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -420,7 +420,8 @@ def test_fused_marlin_moe( score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids = fused_topk(a, score, topk, False) + topk_weights, topk_ids, token_expert_indices = fused_topk( + a, score, topk, False) torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map) diff --git a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py new file mode 100644 index 0000000000000..dfcd61f775870 --- /dev/null +++ b/tests/kernels/moe/test_moe_permute_unpermute.py @@ -0,0 +1,223 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the MOE permute/unpermute kernel + +Run `pytest tests/kernels/test_moe_permute_unpermute.py`. +""" + +from typing import Optional + +import numpy as np +import pytest +import torch + +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.layer import determine_expert_map +from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( + moe_permute, moe_unpermute) +from vllm.platforms import current_platform + +NUM_EXPERTS = [16, 64] +TOP_KS = [2, 4, 6, 8] +EP_SIZE = [1, 4, 16] +current_platform.seed_everything(0) + + +def torch_permute(hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + token_expert_indices: torch.Tensor, + topk: int, + n_expert: int, + n_local_expert: int, + start_expert: int, + expert_map: Optional[torch.Tensor] = None, + align_block_size: Optional[int] = None, + fill_invalid_expert: int = -1) -> list[torch.Tensor]: + n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1] + if expert_map is not None: + is_local_expert = (expert_map[topk_ids] != -1) + not_local_expert = (expert_map[topk_ids] == -1) + topk_ids = is_local_expert * ( + topk_ids - start_expert) + not_local_expert * (topk_ids + n_expert) + + sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(), + stable=True) + dst_row_id2src_row_id_map = token_expert_indices.flatten()[sorted_indices] + + expert_first_token_offset = torch.zeros(n_local_expert + 1, + dtype=torch.int64, + device="cuda") + idx = 0 + for i in range(0, n_local_expert): + cnt = 0 + while idx < sorted_topk_ids.numel() and sorted_topk_ids[idx] == i: + cnt += 1 + idx += 1 + expert_first_token_offset[i + 1] = expert_first_token_offset[i] + cnt + + _, src2dst_idx = torch.sort(dst_row_id2src_row_id_map) + valid_row_idx = [] + if align_block_size is None: + + permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map % + n_token, ...] + permuted_row_size = permuted_hidden_states.shape[0] + m_indices = torch.empty(permuted_row_size, + device="cuda", + dtype=torch.int32).fill_(fill_invalid_expert) + for i in range(1, n_local_expert + 1): + first_token_offset = expert_first_token_offset[i - 1] + last_token_offset = expert_first_token_offset[i] + m_indices[first_token_offset:last_token_offset] = i - 1 + src_row_id2dst_row_id_map = torch.arange( + 0, n_token * topk, device="cuda", + dtype=torch.int32)[src2dst_idx].reshape((n_token, topk)) + valid_row_idx += [i for i in range(expert_first_token_offset[-1])] + return [ + permuted_hidden_states, expert_first_token_offset, + src_row_id2dst_row_id_map, m_indices, valid_row_idx + ] + else: + permuted_row_size = (topk * n_token + n_expert * + (align_block_size - 1) + align_block_size - + 1) // align_block_size * align_block_size + permuted_hidden_states = torch.empty((permuted_row_size, n_hidden), + device="cuda", + dtype=hidden_states.dtype) + align_src_row_id2dst_row_id = torch.empty(n_token * topk, + device="cuda", + dtype=torch.int32) + align_expert_first_token_offset = torch.zeros_like( + expert_first_token_offset) + m_indices = torch.empty(permuted_row_size, + device="cuda", + dtype=torch.int32).fill_(fill_invalid_expert) + # get align_permuted_hidden_states, + # valid row_idx and align_expert_first_token_offset + for i in range(1, n_local_expert + 1): + first_token_offset = expert_first_token_offset[i - 1] + last_token_offset = expert_first_token_offset[i] + n_token_in_expert = last_token_offset - first_token_offset + align_expert_first_token_offset[ + i] = align_expert_first_token_offset[ + i - 1] + (n_token_in_expert + align_block_size - + 1) // align_block_size * align_block_size + align_first_token_offset = align_expert_first_token_offset[i - 1] + align_last_token_offset = align_expert_first_token_offset[i] + dst_row_id2src_row_id_in_expert = dst_row_id2src_row_id_map[ + first_token_offset:first_token_offset + + n_token_in_expert] % n_token + # store token in current expert with align_first_token_offset + permuted_hidden_states[align_first_token_offset:\ + align_first_token_offset+n_token_in_expert,\ + ...] = hidden_states[\ + dst_row_id2src_row_id_in_expert, ...] + # set current expert m_indices + m_indices[align_first_token_offset:align_last_token_offset] = i - 1 + valid_row_idx += [ + i for i in range(align_first_token_offset, + align_first_token_offset + n_token_in_expert) + ] + # get align_src_row_id2dst_row_id + for i in range(n_token * topk): + eid = sorted_topk_ids[i] + if (eid >= n_local_expert): + # check token not in local expert + align_src_row_id2dst_row_id[ + i] = align_expert_first_token_offset[-1] + continue + first_token_offset = expert_first_token_offset[eid] + align_first_token_offset = align_expert_first_token_offset[eid] + token_offset = i - first_token_offset + align_src_row_id2dst_row_id[ + i] = align_first_token_offset + token_offset + align_src_row_id2dst_row_id = align_src_row_id2dst_row_id[\ + src2dst_idx].reshape((n_token, topk)) + return [ + permuted_hidden_states, align_expert_first_token_offset, + align_src_row_id2dst_row_id, m_indices, valid_row_idx + ] + + +def torch_unpermute(permuted_hidden_states: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + token_expert_indices: torch.Tensor, + src_row_id2dst_row_id_map: torch.Tensor, + valid_row_idx: torch.Tensor, topk: int, + n_expert: int) -> torch.Tensor: + # ignore invalid row + mask = torch.zeros(permuted_hidden_states.shape[0], + dtype=bool, + device="cuda") + mask[valid_row_idx] = True + permuted_hidden_states[~mask] = 0 + idx = src_row_id2dst_row_id_map.flatten()[ + token_expert_indices.flatten()].reshape(token_expert_indices.shape) + output = permuted_hidden_states[idx, ...] * topk_weights[..., None] + output = output.sum(dim=1).to(permuted_hidden_states.dtype) + return output + + +@pytest.mark.parametrize("n_token", [1, 33, 64, 222, 1024, 2048, 3000, 5000]) +@pytest.mark.parametrize("n_hidden", [2048, 4096, 7168]) +@pytest.mark.parametrize("n_expert", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("ep_size", EP_SIZE) +@pytest.mark.parametrize("align_block_size", [None, 128]) +def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int, + n_expert: int, ep_size: int, dtype: torch.dtype, + align_block_size: Optional[int]): + fill_invalid_expert = 0 + ep_rank = np.random.randint(0, ep_size) + expert_map = None + n_local_expert = n_expert + if (ep_size != 1): + n_local_expert, expert_map = determine_expert_map( + ep_size, ep_rank, n_expert) + expert_map = expert_map.cuda() + start_expert = n_local_expert * ep_rank + current_platform.seed_everything(0) + hidden_states = torch.randn((n_token, n_hidden), device="cuda").to(dtype) + gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype) + topk_weights, topk_ids, token_expert_indices = fused_topk( + hidden_states, gating_output, topk, False) + gold0, gold1, gold2, gold3, valid_row_idx = torch_permute( + hidden_states, + topk_ids, + token_expert_indices, + topk, + n_expert, + n_local_expert, + start_expert, + expert_map=expert_map, + align_block_size=align_block_size, + fill_invalid_expert=fill_invalid_expert) + + result0, result1, result2, result3 = moe_permute( + hidden_states, topk_weights, topk_ids, token_expert_indices, topk, + n_expert, n_local_expert, expert_map, align_block_size, + fill_invalid_expert) + + # check expert_first_token_offset + torch.testing.assert_close(gold1, result1, atol=0, rtol=0) + # check src_row_id2dst_row_id_map + torch.testing.assert_close(gold2, result2, atol=0, rtol=0) + # check mindice + torch.testing.assert_close(gold3, result3, atol=0, rtol=0) + # check permuted_hidden_states, only valid token + torch.testing.assert_close(gold0[valid_row_idx], + result0[valid_row_idx], + atol=0, + rtol=0) + + # add a random tensor to simulate group gemm + result0 = 0.5 * result0 + torch.randn_like(result0) + + result4 = moe_unpermute(result0, topk_weights, topk_ids, result2, result1, + topk, n_expert, n_local_expert) + gold4 = torch_unpermute(result0, topk_weights, topk_ids, + token_expert_indices, result2, valid_row_idx, topk, + n_local_expert) + + # check unpermuted hidden + torch.testing.assert_close(result4, gold4, atol=2e-2, rtol=0) diff --git a/tests/kernels/quantization/test_awq_marlin.py b/tests/kernels/quantization/test_awq_marlin.py index 939b0e7157be7..c30fe60becdfd 100644 --- a/tests/kernels/quantization/test_awq_marlin.py +++ b/tests/kernels/quantization/test_awq_marlin.py @@ -84,7 +84,8 @@ def test_fused_marlin_moe_awq( score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids = fused_topk(a, score, topk, False) + topk_weights, topk_ids, token_expert_indices = fused_topk( + a, score, topk, False) marlin_output = torch.ops.vllm.fused_marlin_moe( a, qweight1, diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index c57e39f425064..38c7e461bb9c4 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -338,7 +338,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, M, K = a.shape N = w2.shape[-1] - topk_weight, topk_ids = fused_topk(a, score.float(), topk, False) + topk_weight, topk_ids, token_expert_indices = fused_topk( + a, score.float(), topk, False) block_m = deep_gemm.get_m_alignment_for_contiguous_layout() @@ -435,7 +436,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_size) - topk_weights, topk_ids = fused_topk(a, score.float(), topk, False) + topk_weights, topk_ids, token_expert_indices = fused_topk( + a, score.float(), topk, False) out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 62614a59cbe9a..238808b226f43 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -71,8 +71,8 @@ def single_marlin_moe( E = w.shape[0] N = w.shape[2] // (num_bits // 2) - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, - renormalize) + topk_weights, topk_ids, token_expert_indices = fused_topk( + hidden_states, gating_output, topk, renormalize) # This might not be an optimal config for a single MMM get_config_func = functools.partial(try_get_optimal_moe_config, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index a209715ede77c..c1edbda0dd224 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -854,7 +854,7 @@ def fused_topk( gating_output: torch.Tensor, topk: int, renormalize: bool, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: assert hidden_states.shape[0] == gating_output.shape[0], ( "Number of tokens mismatch") @@ -868,20 +868,19 @@ def fused_topk( topk, dtype=torch.int32, device=hidden_states.device) - token_expert_indicies = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) + token_expert_indices = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) gating_output_float = gating_output.float() # TODO(woosuk): Optimize this. topk_func = dispatch_topk_func() topk_weights, topk_ids = topk_func(topk_weights, topk_ids, - token_expert_indicies, + token_expert_indices, gating_output_float, renormalize) - del token_expert_indicies # Not used. Will be used in the future. - return topk_weights, topk_ids + return topk_weights, topk_ids, token_expert_indices # This is used by the Deepseek-V2 and Deepseek-V3 model @@ -1510,8 +1509,8 @@ def fused_moe( topk, renormalize, num_expert_group, topk_group) elif custom_routing_function is None: - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, - renormalize) + topk_weights, topk_ids, token_expert_indices = fused_topk( + hidden_states, gating_output, topk, renormalize) else: topk_weights, topk_ids = custom_routing_function( hidden_states, gating_output, topk, renormalize) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 3cdf3c97a7d3e..35994c8ac6af0 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -801,10 +801,11 @@ class FusedMoE(torch.nn.Module): scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) elif custom_routing_function is None: - topk_weights, topk_ids = fused_topk(hidden_states=hidden_states, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize) + topk_weights, topk_ids, token_expert_indices = fused_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize) else: topk_weights, topk_ids = custom_routing_function( hidden_states=hidden_states, diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py new file mode 100644 index 0000000000000..cdf7e31c1436e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -0,0 +1,116 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + +import torch + + +def moe_permute( + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + token_expert_indices: torch.Tensor, + topk: int, + n_expert: int, + n_local_expert: int, + expert_map: Optional[torch.Tensor] = None, + align_block_size: Optional[int] = None, + fill_invalid_expert: int = -1 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + This function expands and permutes activation to gather uncontinuous tokens + for each expert. + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - topk_weights (torch.Tensor): topk expert route weight for each token. + - topk_ids (torch.Tensor): topk expert route id for each token. + - token_expert_indices (torch.Tensor): indice for expanded hidden. + - topk (int): The number of top-k experts to select. + - n_expert (int): The number of expert. + - n_local_expert (int): The number of expert in current EP rank. + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert + parallel shard. + - align_block_size (Optional[int]): align group gemm block size for deepgemm + - fill_invalid_expert(int): fill expert id in m_indices for invalid expert + to workaround DeepGemm unsupported -1 in m_indices + Returns: + - permuted_hidden_states (torch.Tensor): permuted activation. + - expert_first_token_offset (torch.Tensor): offset of the first token + of each expert for standard grouped gemm. if enable 'align_block_size' + expert_first_token_offset will align up to 'align_block_size'. + - src_row_id2dst_row_id_map (torch.Tensor): idx map for moe_unpermute. + - m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records + the group which the j-th row of the LHS belong to.` + """ + n_token, n_hidden = hidden_states.shape + assert (n_hidden * hidden_states.element_size() + ) % 16 == 0, "permue kernel need hidden dim align to 16B" + permuted_row_size = n_token * topk + if align_block_size is not None: + permuted_row_size = (permuted_row_size + n_expert * + (align_block_size - 1) + align_block_size - + 1) // align_block_size * align_block_size + + permuted_hidden_states = torch.empty( + (permuted_row_size, n_hidden), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + m_indices = torch.full((permuted_row_size, ), + fill_invalid_expert, + dtype=torch.int32, + device=hidden_states.device) + expert_first_token_offset = torch.empty(n_local_expert + 1, + dtype=torch.int64, + device=hidden_states.device) + src_row_id2dst_row_id_map = torch.empty((n_token, topk), + dtype=torch.int32, + device=hidden_states.device) + torch.ops._moe_C.moe_permute(hidden_states, topk_weights, topk_ids, + token_expert_indices, expert_map, n_expert, + n_local_expert, topk, align_block_size, + permuted_hidden_states, + expert_first_token_offset, + src_row_id2dst_row_id_map, m_indices) + return (permuted_hidden_states, expert_first_token_offset, + src_row_id2dst_row_id_map, m_indices) + + +def moe_unpermute( + permuted_hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + src_row_id2dst_row_id_map: torch.Tensor, + expert_first_token_offset: torch.Tensor, + topk: int, + n_expert: int, + n_local_expert: int, +) -> torch.Tensor: + """ + This function expands and permutes activation to gathering uncontinuous + tokens for each expert. + Parameters: + - permuted_hidden_states (torch.Tensor): permuted activation. + - topk_weights (torch.Tensor): topk expert route weight for each token. + - topk_ids (torch.Tensor): topk expert route id for each token. + - expert_first_token_offset (torch.Tensor): offset of the first token + of each expert for grouped gemm. + - topk (int): The number of top-k experts to select. + - n_expert (int): The number of expert. + - n_local_expert (int): The number of expert in current EP rank. + Returns: + - hidden_states (torch.Tensor): The reduced and unpermuted activation + tensor. + """ + n_token, n_hidden = topk_weights.shape[0], permuted_hidden_states.shape[-1] + assert (n_hidden * permuted_hidden_states.element_size() + ) % 16 == 0, "unpermue kernel need hidden dim align to 16B" + hidden_states = torch.empty((n_token, n_hidden), + dtype=permuted_hidden_states.dtype, + device=permuted_hidden_states.device) + + torch.ops._moe_C.moe_unpermute(permuted_hidden_states, topk_weights, + topk_ids, src_row_id2dst_row_id_map, + expert_first_token_offset, n_expert, + n_local_expert, topk, hidden_states) + return hidden_states diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index dfe8f20c70d62..c518efdb54f89 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -175,10 +175,8 @@ class ArcticMoE(nn.Module): # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) do_normalize = self.top_k > 1 - topk_weights, topk_ids = fused_topk(hidden_states, - router_logits, - self.top_k, - renormalize=do_normalize) + topk_weights, topk_ids, token_expert_indices = fused_topk( + hidden_states, router_logits, self.top_k, renormalize=do_normalize) # topk_ids: (num_tokens, k) if self.is_quant: if 2 * num_tokens <= self.num_experts: