diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index 629348bf88764..b3d0c0aa58e9e 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -8,12 +8,77 @@ #include "../cuda_compat.h" #include "../dispatch_utils.h" +#include "core/math.hpp" #define CEILDIV(x, y) (((x) + (y) - 1) / (y)) namespace vllm { namespace moe { +namespace batched_moe_align_block_size { + +// Note num_threads needs to be 1024 for BlockScan Reduction in the kernel. +static constexpr int32_t num_threads = 1024; +static constexpr int32_t num_blocks = 1; +__global__ void batched_moe_align_block_size_kernel( + int32_t const num_batches, int32_t const max_tokens_per_batch, + int32_t const block_size, int32_t const* __restrict__ batch_num_tokens, + int32_t* __restrict__ sorted_ids, int32_t* __restrict__ block_ids, + int32_t* __restrict__ num_tokens_post_pad) { + // TODO(varun): This is a naive implementation. Could be optimized. + + size_t const batch_id = threadIdx.x; + size_t const stride = blockDim.x * gridDim.x; + int32_t const num_blocks_per_batch = + CEILDIV(max_tokens_per_batch, block_size); + int32_t const sorted_ids_size = + num_blocks_per_batch * num_batches * block_size; + int32_t const block_ids_size = sorted_ids_size / block_size; + int32_t const SENTINEL = + num_batches * max_tokens_per_batch; // To denote invalid entries. + // Initialize sorted_ids with the sentinel value + for (size_t i = threadIdx.x; i < sorted_ids_size; i += stride) { + sorted_ids[i] = SENTINEL; + } + // Initialize block_ids with -1 + for (size_t i = threadIdx.x; i < block_ids_size; i += stride) { + block_ids[i] = -1; + } + + int32_t b_num_tokens = 0; + if (batch_id < num_batches) { + b_num_tokens = batch_num_tokens[batch_id]; + } + int32_t const ceil_b_num_tokens = + CEILDIV(b_num_tokens, block_size) * block_size; + + // Compute an exclusive prefix sum over the rounded-up token counts per batch + using BlockScan = cub::BlockScan<int32_t, num_threads>; + __shared__ typename BlockScan::TempStorage temp_storage; + int cumsum_val; + BlockScan(temp_storage).ExclusiveSum(ceil_b_num_tokens, cumsum_val); + __syncthreads(); + + bool const is_last_batch = batch_id == (num_batches - 1); + if (is_last_batch) { + *num_tokens_post_pad = cumsum_val + ceil_b_num_tokens; + } + + if (batch_id < num_batches) { + int32_t const batch_offset = batch_id * max_tokens_per_batch; + for (size_t i = 0; i < b_num_tokens; ++i) { + sorted_ids[cumsum_val + i] = batch_offset + i; + } + + int32_t const block_start = cumsum_val / block_size; + int32_t const num_blocks = ceil_b_num_tokens / block_size; + for (size_t i = 0; i < num_blocks; ++i) { + block_ids[block_start + i] = batch_id; + } + } +} +} // namespace batched_moe_align_block_size + template <typename scalar_t> __global__ void moe_align_block_size_kernel( const scalar_t* __restrict__ topk_ids, @@ -280,6 +345,33 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, }); } +void batched_moe_align_block_size(int64_t max_tokens_per_batch, + int64_t block_size, + torch::Tensor const& batch_num_tokens, + torch::Tensor sorted_ids, + torch::Tensor batch_ids, + torch::Tensor num_tokens_post_pad) { + namespace batched_kernel = vllm::moe::batched_moe_align_block_size; + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + int32_t const B = batch_num_tokens.size(0); + int32_t const num_blocks_per_batch = + round_to_next_multiple_of(max_tokens_per_batch, block_size) / block_size; + int32_t const num_blocks = num_blocks_per_batch * B; + int64_t const sorted_ids_size = num_blocks * block_size; +
TORCH_CHECK(sorted_ids.size(0) == sorted_ids_size); + TORCH_CHECK(batch_ids.size(0) == sorted_ids_size / block_size); + TORCH_CHECK(num_tokens_post_pad.size(0) == 1); + TORCH_CHECK(B <= batched_kernel::num_threads); + + batched_kernel::batched_moe_align_block_size_kernel<<< + batched_kernel::num_blocks, batched_kernel::num_threads, 0, stream>>>( + B, max_tokens_per_batch, block_size, batch_num_tokens.data_ptr<int32_t>(), + sorted_ids.data_ptr<int32_t>(), batch_ids.data_ptr<int32_t>(), + num_tokens_post_pad.data_ptr<int32_t>()); +} + void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size] torch::Tensor& output) // [num_tokens, hidden_size] { diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 92fc280b362b9..86d9cc1848fff 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -12,6 +12,14 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad); + +void batched_moe_align_block_size(int64_t max_tokens_per_batch, + int64_t block_size, + torch::Tensor const& expert_num_tokens, + torch::Tensor sorted_ids, + torch::Tensor expert_ids, + torch::Tensor num_tokens_post_pad); + #ifndef USE_ROCM torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, torch::Tensor b_qweight, torch::Tensor b_scales, diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 8f33d6cd666fa..2c0a515ef643e 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -22,6 +22,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { " Tensor! num_tokens_post_pad) -> ()"); m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); + // Aligning the number of tokens to be processed by each expert such + // that it is divisible by the block size, but for the batched case. + m.def( + "batched_moe_align_block_size(int max_tokens_per_batch," + " int block_size, Tensor expert_num_tokens," + " Tensor! sorted_token_ids," + " Tensor! experts_ids," + " Tensor! num_tokens_post_pad) -> ()"); + m.impl("batched_moe_align_block_size", torch::kCUDA, + &batched_moe_align_block_size); + #ifndef USE_ROCM m.def( "moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, " diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index 0831c5bc790dc..633e23eea33e2 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -92,8 +92,8 @@ To be used with a particular `FusedMoEPrepareAndFinalize` sub-class, MoE kernels | flashinfer | standard | nvfp4,
fp8 | T | 5 | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],
[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] | | gpt oss triton | standard | N/A | N/A | 5 | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],
[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] | | deep gemm+triton2 | standard,
batched | all1 | G(128),A,T | silu, gelu | 6 | Y | [`TritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe.TritonOrDeepGemmExperts],
[`BatchedTritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe.BatchedTritonOrDeepGemmExperts] | -| marlin | standard | 3 | 3 | silu,
swigluoai | Y | N | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe] | -| marlin experts | standard | N/A | N/A | silu,
swigluoai | Y | Y | [`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts] | +| marlin | standard | 3 | 3 | silu,
swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],
[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],
[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] | +| marlin experts | standard,
batched | N/A | N/A | silu,
swigluoai | Y | Y | [`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],
[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] | | trtllm | standard | mxfp4,
nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] | | pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] | | iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] | @@ -115,6 +115,6 @@ The following table shows "families" of modular kernels that are intended to wor | backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses | |----------------------------------|------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------| -| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,
`TritonExperts`,
`TritonOrDeepGemmExperts`,
`CutlassExpertsFp8`,
`MarlinExperts` | -| deepep_low_latency,
pplx | `DeepEPLLPrepareAndFinalize`,
`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,
`BatchedTritonExperts`,
`BatchedTritonOrDeepGemmExperts`,
`CutlassBatchedExpertsFp8`| -| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` | +| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,
`TritonExperts`,
`TritonOrDeepGemmExperts`,
`CutlassExpertsFp8`,
`MarlinExperts` | +| deepep_low_latency,
pplx | `DeepEPLLPrepareAndFinalize`,
`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,
`BatchedTritonExperts`,
`BatchedTritonOrDeepGemmExperts`,
`CutlassBatchedExpertsFp8`,
`BatchedMarlinExperts`| +| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` | diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 966e2f8f3b131..2c802ff4e6bd6 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -7,6 +7,8 @@ Run `pytest tests/kernels/test_moe.py`. import functools from collections.abc import Callable +from dataclasses import dataclass +from typing import Any import pytest import torch @@ -26,7 +28,10 @@ from vllm.model_executor.layers.fused_moe.config import ( int4_w4a16_moe_quant_config, int8_w8a16_moe_quant_config, ) -from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe +from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + batched_fused_marlin_moe, + fused_marlin_moe, +) from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, modular_triton_fused_moe, @@ -564,6 +569,105 @@ def marlin_moe_generate_valid_test_cases(): return cases +@dataclass +class MarlinMoEWeightData: + w_ref: torch.Tensor + qweight: torch.Tensor + scales: torch.Tensor + global_scale: torch.Tensor | None + g_idx: torch.Tensor | None + zeros: torch.Tensor | None + sort_indices: torch.Tensor | None + marlin_bias: torch.Tensor | None + + @staticmethod + def make( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool | None = None, + bias: torch.Tensor | None = None, + ) -> "MarlinMoEWeightData": + assert w.ndim == 3 + has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] + k = w.shape[-1] + + w_ref_l: list[torch.Tensor] = [] + qweight_l: list[torch.Tensor] = [] + scales_l: list[torch.Tensor] = [] + global_scale_l: list[torch.Tensor] = [] + zeros_l: list[torch.Tensor] = [] + g_idx_l: list[torch.Tensor] = [] + sort_indices_l: list[torch.Tensor] = [] + bias_l: list[torch.Tensor] = [] + + for i in range(w.shape[0]): + if quant_type == scalar_types.float4_e2m1f: + if group_size == 16: + w_ref, qweight, scales, global_scale = ( + rand_marlin_weight_nvfp4_like(w[i], group_size) + ) + else: + w_ref, qweight, scales = rand_marlin_weight_mxfp4_like( + w[i], group_size + ) + global_scale = None + + w_ref_l.append(w_ref.T) + qweight_l.append(qweight) + scales_l.append(scales) + if global_scale is not None: + global_scale_l.append(global_scale) + elif quant_type == scalar_types.float8_e4m3fn: + w_ref, qweight, scales = marlin_quant_fp8_torch(w[i], group_size) + w_ref_l.append(w_ref.T) + qweight_l.append(qweight) + scales_l.append(scales) + elif has_zp: + w_ref, qweight, scales, zeros = awq_marlin_quantize( + w[i].transpose(1, 0), quant_type, group_size + ) + + w_ref_l.append(w_ref.T) + qweight_l.append(qweight) + scales_l.append(scales) + zeros_l.append(zeros) + else: + test_perm = torch.randperm(k) + w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize( + w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm + ) + + w_ref_l.append(w_ref.T) + qweight_l.append(qweight) + scales_l.append(scales) + g_idx_l.append(g_idx) + sort_indices_l.append(sort_indices) + + if bias is not None: + bias_l.append(marlin_permute_bias(bias[i])) + + w_ref = stack_and_dev(w_ref_l) + qweight = stack_and_dev(qweight_l).contiguous() + scales = stack_and_dev(scales_l) + global_scale = stack_and_dev(global_scale_l) if global_scale_l else None + g_idx = stack_and_dev(g_idx_l) if g_idx_l else None + zeros = stack_and_dev(zeros_l) if zeros_l else None + sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None + marlin_bias = 
stack_and_dev(bias_l) if bias_l else None + + return MarlinMoEWeightData( + w_ref=w_ref, + qweight=qweight, + scales=scales, + global_scale=global_scale, + g_idx=g_idx, + zeros=zeros, + sort_indices=sort_indices, + marlin_bias=marlin_bias, + ) + + @pytest.mark.flaky(reruns=2) @pytest.mark.parametrize( ("m, n, k, e, topk, ep_size, dtype, group_size,act_order, quant_type, is_k_full"), @@ -584,7 +688,6 @@ def test_fused_marlin_moe( is_k_full: bool, ): torch.cuda.manual_seed(0) - has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20 @@ -600,152 +703,44 @@ def test_fused_marlin_moe( else: e_map = None - w_ref1_l = [] - qweight1_l = [] - scales1_l = [] - global_scale1_l = [] - zeros1_l = [] - g_idx1_l = [] - sort_indices1_l = [] + w1_data = MarlinMoEWeightData.make( + w=w1, quant_type=quant_type, group_size=group_size, act_order=act_order + ) - for i in range(w1.shape[0]): - if quant_type == scalar_types.float4_e2m1f: - if group_size == 16: - w_ref1, qweight1, scales1, global_scale1 = ( - rand_marlin_weight_nvfp4_like(w1[i], group_size) - ) - else: - w_ref1, qweight1, scales1 = rand_marlin_weight_mxfp4_like( - w1[i], group_size - ) - global_scale1 = None - - w_ref1_l.append(w_ref1.T) - qweight1_l.append(qweight1) - scales1_l.append(scales1) - if global_scale1 is not None: - global_scale1_l.append(global_scale1) - elif quant_type == scalar_types.float8_e4m3fn: - w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(w1[i], group_size) - w_ref1_l.append(w_ref1.T) - qweight1_l.append(qweight1) - scales1_l.append(scales1) - elif has_zp: - w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize( - w1[i].transpose(1, 0), quant_type, group_size - ) - - w_ref1_l.append(w_ref1.T) - qweight1_l.append(qweight1) - scales1_l.append(scales1) - zeros1_l.append(zeros1) - else: - test_perm = torch.randperm(k) - w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize( - w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm - ) - - w_ref1_l.append(w_ref1.T) - qweight1_l.append(qweight1) - scales1_l.append(scales1) - g_idx1_l.append(g_idx1) - sort_indices1_l.append(sort_indices1) - - w_ref1 = stack_and_dev(w_ref1_l) - qweight1 = stack_and_dev(qweight1_l).contiguous() - scales1 = stack_and_dev(scales1_l) - global_scale1 = stack_and_dev(global_scale1_l) if global_scale1_l else None - g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None - zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None - sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None - - w_ref2_l = [] - qweight2_l = [] - scales2_l = [] - global_scale2_l = [] - zeros2_l = [] - g_idx2_l = [] - sort_indices2_l = [] - - for i in range(w2.shape[0]): - if quant_type == scalar_types.float4_e2m1f: - if group_size == 16: - w_ref2, qweight2, scales2, global_scale2 = ( - rand_marlin_weight_nvfp4_like(w2[i], group_size) - ) - else: - w_ref2, qweight2, scales2 = rand_marlin_weight_mxfp4_like( - w2[i], group_size - ) - global_scale2 = None - - w_ref2_l.append(w_ref2.T) - qweight2_l.append(qweight2) - scales2_l.append(scales2) - if global_scale2 is not None: - global_scale2_l.append(global_scale2) - elif quant_type == scalar_types.float8_e4m3fn: - w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(w2[i], group_size) - w_ref2_l.append(w_ref2.T) - qweight2_l.append(qweight2) - scales2_l.append(scales2) - elif has_zp: - w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize( - 
w2[i].transpose(1, 0), quant_type, group_size - ) - - w_ref2_l.append(w_ref2.T) - qweight2_l.append(qweight2) - scales2_l.append(scales2) - zeros2_l.append(zeros2) - else: - test_perm = torch.randperm(n) - w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize( - w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm - ) - - w_ref2_l.append(w_ref2.T) - qweight2_l.append(qweight2) - scales2_l.append(scales2) - g_idx2_l.append(g_idx2) - sort_indices2_l.append(sort_indices2) - - w_ref2 = stack_and_dev(w_ref2_l) - qweight2 = stack_and_dev(qweight2_l).contiguous() - scales2 = stack_and_dev(scales2_l) - global_scale2 = stack_and_dev(global_scale2_l) if global_scale2_l else None - g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None - zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None - sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None + w2_data = MarlinMoEWeightData.make( + w=w2, quant_type=quant_type, group_size=group_size, act_order=act_order + ) score = torch.randn((m, e), device="cuda", dtype=dtype) topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) with set_current_vllm_config(vllm_config): - torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map) + torch_output = torch_moe( + a, w1_data.w_ref, w2_data.w_ref, score, topk, expert_map=e_map + ) marlin_output = fused_marlin_moe( a, - qweight1, - qweight2, + w1_data.qweight, + w2_data.qweight, None, None, - scales1, - scales2, + w1_data.scales, + w2_data.scales, score, topk_weights, topk_ids, global_num_experts=e, expert_map=e_map, - global_scale1=global_scale1, - global_scale2=global_scale2, - g_idx1=g_idx1, - g_idx2=g_idx2, - sort_indices1=sort_indices1, - sort_indices2=sort_indices2, - w1_zeros=zeros1, - w2_zeros=zeros2, + global_scale1=w1_data.global_scale, + global_scale2=w2_data.global_scale, + g_idx1=w1_data.g_idx, + g_idx2=w2_data.g_idx, + sort_indices1=w1_data.sort_indices, + sort_indices2=w2_data.sort_indices, + w1_zeros=w1_data.zeros, + w2_zeros=w2_data.zeros, quant_type_id=quant_type.id, is_k_full=is_k_full, ) @@ -773,92 +768,52 @@ def test_fused_marlin_moe_with_bias(m): b_bias1 = torch.randn((e, 2 * n), device="cuda", dtype=dtype) / 10 b_bias2 = torch.randn((e, k), device="cuda", dtype=dtype) / 10 - b_bias1_l = [] - w_ref1_l = [] - qweight1_l = [] - scales1_l = [] - g_idx1_l = [] - sort_indices1_l = [] + w1_data = MarlinMoEWeightData.make( + w=w1, + quant_type=quant_type, + group_size=group_size, + act_order=act_order, + bias=b_bias1, + ) - for i in range(w1.shape[0]): - test_perm = torch.randperm(k) - w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize( - w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm - ) - - w_ref1_l.append(w_ref1.T) - qweight1_l.append(qweight1) - scales1_l.append(scales1) - g_idx1_l.append(g_idx1) - sort_indices1_l.append(sort_indices1) - b_bias1_l.append(marlin_permute_bias(b_bias1[i])) - - w_ref1 = stack_and_dev(w_ref1_l) - qweight1 = stack_and_dev(qweight1_l).contiguous() - scales1 = stack_and_dev(scales1_l) - global_scale1 = None - g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None - zeros1 = None - sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None - marlin_bias1 = stack_and_dev(b_bias1_l) if b_bias1_l else None - - b_bias2_l = [] - w_ref2_l = [] - qweight2_l = [] - scales2_l = [] - g_idx2_l = [] - sort_indices2_l = [] - - for i in range(w2.shape[0]): - test_perm = torch.randperm(n) - w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = 
marlin_quantize( - w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm - ) - - w_ref2_l.append(w_ref2.T) - qweight2_l.append(qweight2) - scales2_l.append(scales2) - g_idx2_l.append(g_idx2) - sort_indices2_l.append(sort_indices2) - b_bias2_l.append(marlin_permute_bias(b_bias2[i])) - - w_ref2 = stack_and_dev(w_ref2_l) - qweight2 = stack_and_dev(qweight2_l).contiguous() - scales2 = stack_and_dev(scales2_l) - global_scale2 = None - g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None - zeros2 = None - sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None - marlin_bias2 = stack_and_dev(b_bias2_l) if b_bias2_l else None + w2_data = MarlinMoEWeightData.make( + w=w2, + quant_type=quant_type, + group_size=group_size, + act_order=act_order, + bias=b_bias2, + ) score = torch.randn((m, e), device="cuda", dtype=dtype) topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) with set_current_vllm_config(vllm_config): - torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1, b_bias2) + torch_output = torch_moe( + a, w1_data.w_ref, w2_data.w_ref, score, topk, b_bias1, b_bias2 + ) marlin_output = fused_marlin_moe( a, - qweight1, - qweight2, - marlin_bias1, - marlin_bias2, - scales1, - scales2, + w1_data.qweight, + w2_data.qweight, + w1_data.marlin_bias, + w2_data.marlin_bias, + w1_data.scales, + w2_data.scales, score, topk_weights, topk_ids, global_num_experts=e, expert_map=None, - global_scale1=global_scale1, - global_scale2=global_scale2, - g_idx1=g_idx1, - g_idx2=g_idx2, - sort_indices1=sort_indices1, - sort_indices2=sort_indices2, - w1_zeros=zeros1, - w2_zeros=zeros2, + global_scale1=w1_data.global_scale, + global_scale2=w2_data.global_scale, + g_idx1=w1_data.g_idx, + g_idx2=w2_data.g_idx, + sort_indices1=w1_data.sort_indices, + sort_indices2=w2_data.sort_indices, + w1_zeros=w1_data.zeros, + w2_zeros=w2_data.zeros, quant_type_id=quant_type.id, is_k_full=is_k_full, ) @@ -895,6 +850,41 @@ def test_moe_align_block_size_opcheck(): ) +def test_batched_moe_align_block_size_opcheck(): + max_tokens_per_batch = 512 + num_experts = 4 + block_size = 16 + + expert_num_tokens = torch.randint( + low=0, + high=max_tokens_per_batch, + size=(num_experts,), + dtype=torch.int32, + device="cuda", + ) + + max_num_tokens_padded = num_experts * max(max_tokens_per_batch, block_size) + sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda") + + assert max_num_tokens_padded % block_size == 0 + max_num_m_blocks = max_num_tokens_padded // block_size + expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda") + + num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device="cuda") + + opcheck( + torch.ops._moe_C.batched_moe_align_block_size, + ( + max_tokens_per_batch, + block_size, + expert_num_tokens, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ), + ) + + @pytest.mark.parametrize("m", [1, 33, 64, 222]) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("k", [128, 511, 1024]) @@ -979,3 +969,171 @@ def test_cpu_fused_moe_basic(m, n, k, e, topk, dtype, with_bias, activation): else: atol = 5e-2 torch.testing.assert_close(out, ref, atol=atol, rtol=0) + + +@pytest.mark.parametrize("m", [16, 32, 64]) +@pytest.mark.parametrize("n", [128]) +@pytest.mark.parametrize("k", [128]) +@pytest.mark.parametrize("e", [8, 12, 16, 32]) +@pytest.mark.parametrize("topk", [2, 4]) +@pytest.mark.parametrize("max_tokens_per_batch", [16, 32, 64]) +@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") +def 
test_batched_fused_marlin_moe( + m: int, n: int, k: int, e: int, topk: int, max_tokens_per_batch: int +): + print( + f"testing m={m}, n={n}, k={k}, e={e}, " + f"topk={topk}, " + f"max_tokens_per_batch={max_tokens_per_batch}" + ) + torch.cuda.manual_seed(0) + + dtype = torch.bfloat16 + quant_dtype = scalar_types.float4_e2m1f + group_size = 32 + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20 + + w1_data = MarlinMoEWeightData.make( + w=w1, quant_type=quant_dtype, group_size=group_size, act_order=None + ) + w2_data = MarlinMoEWeightData.make( + w=w2, quant_type=quant_dtype, group_size=group_size, act_order=None + ) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) + + class BatchedRun: + @staticmethod + def _make_expert_num_tokens_cpu( + e: int, # num_experts + topk_ids_cpu: torch.Tensor, + ) -> torch.Tensor: + expert_num_tokens_cpu = torch.zeros((e,), dtype=torch.int32, device="cpu") + for topk_id in torch.flatten(topk_ids_cpu): + expert_num_tokens_cpu[topk_id] += 1 + return expert_num_tokens_cpu + + def __init__( + self, + max_tokens_per_batch: int, + num_experts: int, + _topk_ids: torch.Tensor, + _topk_weights: torch.Tensor, + ): + self.max_tokens_per_batch = max_tokens_per_batch + self.e = num_experts + self.topk_ids_cpu = _topk_ids.to("cpu") + self.topk_weights_cpu = _topk_weights.to("cpu") + self.expert_num_tokens_cpu = self._make_expert_num_tokens_cpu( + self.e, self.topk_ids_cpu + ) + + def is_valid(self): + """ + Return True only if the input can be represented in a Batched + format. + """ + return torch.all(self.expert_num_tokens_cpu <= self.max_tokens_per_batch) + + def _scatter(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states_cpu = hidden_states.to("cpu") + K = hidden_states_cpu.size(1) + batched_hidden_states_cpu = torch.empty( + (e, max_tokens_per_batch, K), + dtype=hidden_states_cpu.dtype, + device="cpu", + ) + + counter_cpu = torch.zeros_like(self.expert_num_tokens_cpu) + for t_idx, token in enumerate(hidden_states_cpu): + for topk_id in self.topk_ids_cpu[t_idx]: + pos_in_batch = counter_cpu[topk_id] + batched_hidden_states_cpu[topk_id, pos_in_batch] = token + counter_cpu[topk_id] += 1 + assert torch.allclose(counter_cpu, self.expert_num_tokens_cpu) + return batched_hidden_states_cpu.to("cuda") + + def _gather( + self, batched_outputs: torch.Tensor, gather_outputs: torch.Tensor + ) -> torch.Tensor: + batched_outputs_cpu = batched_outputs.to("cpu") + gather_outputs_cpu = torch.zeros_like(gather_outputs) + + counter_cpu = torch.zeros((e,), device="cpu", dtype=torch.int32) + md = gather_outputs_cpu.size(0) + for t_idx in range(md): + token = None + for topk_id, topk_weight in zip( + self.topk_ids_cpu[t_idx], self.topk_weights_cpu[t_idx] + ): + pos_in_batch = counter_cpu[topk_id] + t = batched_outputs_cpu[topk_id, pos_in_batch] * topk_weight + if token is None: + token = t + else: + token += t + counter_cpu[topk_id] += 1 + assert token is not None + gather_outputs_cpu[t_idx] = token + gather_outputs.copy_(gather_outputs_cpu) + return gather_outputs + + def run( + self, hidden_states: torch.Tensor, fused_marlin_moe_kwargs: dict[Any, Any] + ) -> torch.Tensor: + assert hidden_states.ndim == 2 + assert self.is_valid() + + batched_hidden_states = self._scatter(hidden_states) + + kwargs = fused_marlin_moe_kwargs | { + "hidden_states": batched_hidden_states, 
+ "expert_num_tokens": self.expert_num_tokens_cpu.to("cuda"), + } + batched_outputs = batched_fused_marlin_moe(**kwargs) + + output = torch.zeros_like(hidden_states) + output = self._gather(batched_outputs, output) + return output + + kwargs = { + "w1": w1_data.qweight, + "w2": w2_data.qweight, + "bias1": None, + "bias2": None, + "w1_scale": w1_data.scales, + "w2_scale": w2_data.scales, + "gating_output": score, + "global_num_experts": e, + "expert_map": None, + "global_scale1": w1_data.global_scale, + "global_scale2": w2_data.global_scale, + "g_idx1": w1_data.g_idx, + "g_idx2": w2_data.g_idx, + "sort_indices1": w1_data.sort_indices, + "sort_indices2": w2_data.sort_indices, + "w1_zeros": w1_data.zeros, + "w2_zeros": w2_data.zeros, + "quant_type_id": quant_dtype.id, + "is_k_full": True, + } + + # Reference + fused_marlin_moe_kwargs = kwargs | { + "hidden_states": a, + "topk_ids": topk_ids, + "topk_weights": topk_weights, + } + ref_marlin_output = fused_marlin_moe(**fused_marlin_moe_kwargs) + + # Batched + br = BatchedRun(max_tokens_per_batch, e, topk_ids, topk_weights) + if not br.is_valid(): + pytest.skip("Cannot represent data in Batched Format.") + marlin_output = br.run(a, kwargs) + + torch.testing.assert_close(marlin_output, ref_marlin_output, atol=1e-3, rtol=0) diff --git a/tests/kernels/moe/test_moe_align_block_size.py b/tests/kernels/moe/test_moe_align_block_size.py index 6f779c6950150..bde0478d9c18d 100644 --- a/tests/kernels/moe/test_moe_align_block_size.py +++ b/tests/kernels/moe/test_moe_align_block_size.py @@ -9,6 +9,7 @@ import pytest import torch from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( + batched_moe_align_block_size, moe_align_block_size, ) from vllm.platforms import current_platform @@ -300,3 +301,96 @@ def test_moe_align_block_size_deterministic(): assert torch.equal(results[0][2], results[i][2]), ( "num_tokens should be deterministic" ) + + +@pytest.mark.parametrize("max_tokens_per_batch", [13, 16, 512]) +@pytest.mark.parametrize("num_experts", [8, 16, 32, 64]) +@pytest.mark.parametrize("block_size", [8, 16, 32, 64]) +@pytest.mark.parametrize("simulate_empty_batches", [False, True]) +def test_batched_moe_align_block_size( + max_tokens_per_batch: int, + num_experts: int, + block_size: int, + simulate_empty_batches: bool, +): + def ref_outputs( + expert_num_tokens: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + E = expert_num_tokens.size(0) + + # Round up so each batch can be split to blocks evenly. 
+ Msum = round_up(max_tokens_per_batch, block_size) * E + ref_sorted_ids = torch.empty((Msum,), dtype=torch.int32) + ref_expert_ids = torch.empty((Msum // block_size,), dtype=torch.int32) + ref_num_tokens_post_pad = torch.empty((1,), dtype=torch.int32) + + # Intialize + sentinel = E * max_tokens_per_batch + ref_sorted_ids.fill_(sentinel) + ref_expert_ids.fill_(-1) + + # Fill ref_sorted_ids + i = 0 + for expert_id, expert_nt in enumerate(expert_num_tokens): + token_offset = expert_id * max_tokens_per_batch + for j in range(expert_nt): + ref_sorted_ids[i] = token_offset + j + i += 1 + # round up i to the next block_size + i = round_up(i, block_size) + + ref_num_tokens_post_pad[0] = i + + # Fill expert_ids + nt_ceil_sum = 0 + for expert_id, expert_nt in enumerate(expert_num_tokens): + expert_ids_offset = nt_ceil_sum // block_size + ceil_expert_nt = round_up(int(expert_nt.item()), block_size) + num_blocks = ceil_expert_nt // block_size + for x in range(num_blocks): + ref_expert_ids[expert_ids_offset + x] = expert_id + nt_ceil_sum += ceil_expert_nt + + return ( + ref_sorted_ids.to("cuda"), + ref_expert_ids.to("cuda"), + ref_num_tokens_post_pad.to("cuda"), + ) + + # Compute expert_num_tokens + expert_num_tokens = torch.randint( + low=0, + high=max_tokens_per_batch, + size=(num_experts,), + device="cpu", + dtype=torch.int32, + ) + if simulate_empty_batches: + # mark half the batches to have 0 tokens + zero_batches = torch.randperm(num_experts)[: num_experts // 2] + expert_num_tokens[zero_batches] = 0 + + # ref outputs + ref_sorted_ids, ref_expert_ids, ref_num_tokens_post_pad = ref_outputs( + expert_num_tokens + ) + + # outputs + sorted_ids, expert_ids, num_tokens_post_pad = batched_moe_align_block_size( + max_tokens_per_batch, block_size, expert_num_tokens.to("cuda") + ) + + assert ref_sorted_ids.size() == sorted_ids.size(), ( + f"{ref_sorted_ids.size()} vs {sorted_ids.size()}" + ) + assert ref_expert_ids.size() == expert_ids.size(), ( + f"{ref_expert_ids.size()} vs {expert_ids.size()}" + ) + assert ref_num_tokens_post_pad.size() == num_tokens_post_pad.size(), ( + f"{ref_num_tokens_post_pad.size()} vs {num_tokens_post_pad.size()}" + ) + torch.testing.assert_close(ref_sorted_ids, sorted_ids, atol=0, rtol=0) + torch.testing.assert_close(ref_expert_ids, expert_ids, atol=0, rtol=0) + torch.testing.assert_close( + ref_num_tokens_post_pad, num_tokens_post_pad, atol=0, rtol=0 + ) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index f1ed3bac80c60..dbbfc01e3bb4d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1789,6 +1789,24 @@ def moe_align_block_size( ) +def batched_moe_align_block_size( + max_tokens_per_batch: int, + block_size: int, + expert_num_tokens: torch.Tensor, + sorted_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + torch.ops._moe_C.batched_moe_align_block_size( + max_tokens_per_batch, + block_size, + expert_num_tokens, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + + def moe_wna16_gemm( input: torch.Tensor, output: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index b3ba2e308953a..a6558c56db845 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -50,7 +50,31 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): # DeepEP low-latency kernels are compiled only for certain # specific 
hidden sizes. - SUPPORTED_HIDDEN_SIZES = [2048, 2560, 4096, 5120, 6144, 7168] + # NOTE: Keep this list sorted, maybe_roundup_layer_hidden_size depends + # on it. + SUPPORTED_HIDDEN_SIZES = [2048, 2560, 4096, 5120, 6144, 7168, 8192] + + @staticmethod + def maybe_roundup_layer_hidden_size(hidden_size: int) -> int: + # Round up hidden size to the closest supported hidden size. + _supported_hs = DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES + # Check sorted + num_supported_hs = len(_supported_hs) + assert all( + [ + _supported_hs[i] < _supported_hs[i + 1] + for i in range(num_supported_hs - 1) + ] + ) + + for x in _supported_hs: + if x >= hidden_size: + return x + + raise ValueError( + f"Hidden Size {hidden_size} is greater than the " + f"maximum supported hidden size {_supported_hs[-1]}" + ) def __init__( self, diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 57e17f324d2e8..e457b729da8c5 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -3,13 +3,16 @@ """Fused MoE utilities for GPTQ.""" import torch -from typing_extensions import override import vllm._custom_ops as ops import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size +from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( + batched_moe_align_block_size, + moe_align_block_size, +) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP, ) from vllm.model_executor.layers.fused_moe.utils import _resize_cache, disable_inplace @@ -21,6 +24,153 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.scalar_type import ScalarType, scalar_types +def _fused_marlin_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + bias1: torch.Tensor | None, + bias2: torch.Tensor | None, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + num_topk: int, + quant_type: ScalarType, + apply_router_weight_on_input: bool, + activation: str, + expert_map: torch.Tensor | None, + block_size_m: int, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + global_scale1: torch.Tensor | None = None, + global_scale2: torch.Tensor | None = None, + g_idx1: torch.Tensor | None = None, + g_idx2: torch.Tensor | None = None, + sort_indices1: torch.Tensor | None = None, + sort_indices2: torch.Tensor | None = None, + w1_zeros: torch.Tensor | None = None, + w2_zeros: torch.Tensor | None = None, + workspace: torch.Tensor | None = None, + intermediate_cache13: torch.Tensor | None = None, + intermediate_cache2: torch.Tensor | None = None, + output: torch.Tensor | None = None, + is_k_full: bool = True, +) -> torch.Tensor: + assert hidden_states.ndim == 2 + M, K = hidden_states.size() + N = marlin_moe_intermediate_size(w1, w2) + + if workspace is None: + workspace = marlin_make_workspace_new(hidden_states.device, 4) + + if intermediate_cache13 is None: + intermediate_cache13 = torch.empty( + (M * num_topk * max(2 * N, K),), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + if intermediate_cache2 is None: + intermediate_cache2 = torch.empty( + (M * num_topk, N), + device=hidden_states.device, + 
dtype=hidden_states.dtype, + ) + + intermediate_cache1 = _resize_cache(intermediate_cache13, (M * num_topk, 2 * N)) + + intermediate_cache3 = _resize_cache(intermediate_cache13, (M * num_topk, K)) + + intermediate_cache2 = _resize_cache(intermediate_cache2, (M * num_topk, N)) + + maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype) + use_atomic_add = ( + hidden_states.dtype == torch.half + or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9 + ) + + intermediate_cache1 = ops.moe_wna16_marlin_gemm( + hidden_states, + intermediate_cache1, + w1, + bias1, + w1_scale, + global_scale1, + w1_zeros, + g_idx1, + sort_indices1, + workspace, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + topk_weights, + moe_block_size=block_size_m, + top_k=num_topk, + mul_topk_weights=apply_router_weight_on_input, + is_ep=expert_map is not None, + b_q_type=quant_type, + size_m=M, + size_n=2 * N, + size_k=K, + is_k_full=is_k_full, + use_atomic_add=use_atomic_add, + use_fp32_reduce=True, + is_zp_float=False, + ) + + if activation == "silu": + torch.ops._C.silu_and_mul( + intermediate_cache2, intermediate_cache1.view(-1, 2 * N) + ) + elif activation == "swigluoai": + # alpha = 1.702, limit = 7.0 + torch.ops._C.swigluoai_and_mul( + intermediate_cache2, intermediate_cache1.view(-1, 2 * N) + ) + else: + raise ValueError( + f"Unsupported activation: {activation}. " + "Only silu and swigluoai activations are supported." + ) + + if output is None: + output = intermediate_cache3 + + if expert_map is not None: + output.zero_() + + output = ops.moe_wna16_marlin_gemm( + intermediate_cache2, + output, + w2, + bias2, + w2_scale, + global_scale2, + w2_zeros, + g_idx2, + sort_indices2, + workspace, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + topk_weights, + moe_block_size=block_size_m, + top_k=1, + mul_topk_weights=not apply_router_weight_on_input, + is_ep=expert_map is not None, + b_q_type=quant_type, + size_m=M * num_topk, + size_n=K, + size_k=N, + is_k_full=is_k_full, + use_atomic_add=use_atomic_add, + use_fp32_reduce=True, + is_zp_float=False, + ) + + return output + + def fused_marlin_moe( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -62,23 +212,27 @@ def fused_marlin_moe( - w2 (torch.Tensor): The second set of expert weights. - w1_scale (torch.Tensor): Scale to be used for w1. - w2_scale (torch.Tensor): Scale to be used for w2. - - gating_output (Optional[torch.Tensor]): The output of the gating + - gating_output (torch.Tensor|None): The output of the gating operation (before softmax). - - g_idx1 (Optional[torch.Tensor]): The first set of act_order indices. - - g_idx2 (Optional[torch.Tensor]): The second set of act_order indices. - - sort_indices1 (Optional[torch.Tensor]): The first act_order input + - g_idx1 (torch.Tensor|None): The first set of act_order indices. + - g_idx2 (torch.Tensor|None): The second set of act_order indices. + - sort_indices1 (torch.Tensor|None): The first act_order input permutation. - - sort_indices2 (Optional[torch.Tensor]): The second act_order input + - sort_indices2 (torch.Tensor|None): The second act_order input permutation. - topk_weights (torch.Tensor): Top-k weights. - topk_ids (torch.Tensor): Indices of topk-k elements. - - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1. - - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2. + - w1_zeros (torch.Tensor|None): Optional zero points to be used for w1. + - w2_zeros (torch.Tensor|None): Optional zero points to be used for w2. 
- num_bits (bool): The number of bits in expert weights quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. """ + + if inplace: + assert output is None, "Conflicting request" + quant_type = ScalarType.from_id(quant_type_id) assert quant_type in [ scalar_types.uint4, @@ -95,15 +249,15 @@ def fused_marlin_moe( ] num_bits = 4 if quant_type in bit4_scalar_types else 8 + M, K = hidden_states.size() + E = w1.size(0) + topk = topk_ids.size(1) + # Check constraints. if gating_output is not None: - assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch" - ) - assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1" - assert hidden_states.shape[1] == w2.shape[2] // (num_bits // 2), ( - "Hidden size mismatch w2" - ) + assert gating_output.size(0) == M, "Number of tokens mismatch" + assert w1.size(1) * 16 == K, "Hidden size mismatch w1" + assert w2.size(2) // (num_bits // 2) == K, "Hidden size mismatch w2" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" @@ -111,11 +265,6 @@ def fused_marlin_moe( assert num_bits in [4, 8] assert topk_weights.dtype == torch.float32 - M, K = hidden_states.shape - E = w1.shape[0] - N = marlin_moe_intermediate_size(w1, w2) - topk = topk_ids.shape[1] - # M block size selection logic # TODO: tune this further for specific models for block_size_m in [8, 16, 32, 48, 64]: @@ -128,107 +277,38 @@ def fused_marlin_moe( topk_ids, block_size_m, global_num_experts, expert_map ) - if workspace is None: - workspace = marlin_make_workspace_new(hidden_states.device, 4) - - if intermediate_cache2 is None: - intermediate_cache2 = torch.empty( - (M * topk, N), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - - if intermediate_cache13 is None: - intermediate_cache13 = torch.empty( - (M * topk * max(2 * N, K),), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - - intermediate_cache1 = _resize_cache(intermediate_cache13, (M * topk, 2 * N)) - intermediate_cache3 = _resize_cache(intermediate_cache13, (M * topk, K)) - intermediate_cache2 = _resize_cache(intermediate_cache2, (M * topk, N)) - - maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype) - use_atomic_add = ( - hidden_states.dtype == torch.half - or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9 - ) - - intermediate_cache1 = ops.moe_wna16_marlin_gemm( - hidden_states, - intermediate_cache1, - w1, - bias1, - w1_scale, - global_scale1, - w1_zeros, - g_idx1, - sort_indices1, - workspace, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - topk_weights, - moe_block_size=block_size_m, - top_k=topk, - mul_topk_weights=apply_router_weight_on_input, - is_ep=expert_map is not None, - b_q_type=quant_type, - size_m=M, - size_n=2 * N, - size_k=K, + assert activation is not None + moe_output = _fused_marlin_moe( + hidden_states=hidden_states, + w1=w1, + w2=w2, + bias1=bias1, + bias2=bias2, + w1_scale=w1_scale, + w2_scale=w2_scale, + topk_weights=topk_weights, + num_topk=topk, + quant_type=quant_type, + apply_router_weight_on_input=apply_router_weight_on_input, + activation=activation, + expert_map=expert_map, + block_size_m=block_size_m, + sorted_token_ids=sorted_token_ids, + expert_ids=expert_ids, + num_tokens_post_padded=num_tokens_post_padded, + global_scale1=global_scale1, + global_scale2=global_scale2, + g_idx1=g_idx1, + g_idx2=g_idx2, 
+ sort_indices1=sort_indices1, + sort_indices2=sort_indices2, + w1_zeros=w1_zeros, + w2_zeros=w2_zeros, + workspace=workspace, + intermediate_cache13=intermediate_cache13, + intermediate_cache2=intermediate_cache2, + output=None, is_k_full=is_k_full, - use_atomic_add=use_atomic_add, - use_fp32_reduce=True, - is_zp_float=False, - ) - - if activation == "silu": - torch.ops._C.silu_and_mul( - intermediate_cache2, intermediate_cache1.view(-1, 2 * N) - ) - elif activation == "swigluoai": - # alpha = 1.702, limit = 7.0 - torch.ops._C.swigluoai_and_mul( - intermediate_cache2, intermediate_cache1.view(-1, 2 * N) - ) - else: - raise ValueError( - f"Unsupported activation: {activation}. " - "Only silu and swigluoai activations are supported." - ) - - if expert_map is not None: - intermediate_cache3.zero_() - - intermediate_cache3 = ops.moe_wna16_marlin_gemm( - intermediate_cache2, - intermediate_cache3, - w2, - bias2, - w2_scale, - global_scale2, - w2_zeros, - g_idx2, - sort_indices2, - workspace, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - topk_weights, - moe_block_size=block_size_m, - top_k=1, - mul_topk_weights=not apply_router_weight_on_input, - is_ep=expert_map is not None, - b_q_type=quant_type, - size_m=M * topk, - size_n=K, - size_k=N, - is_k_full=is_k_full, - use_atomic_add=use_atomic_add, - use_fp32_reduce=True, - is_zp_float=False, ).view(-1, topk, K) if output is None: @@ -237,16 +317,173 @@ def fused_marlin_moe( else: output = torch.empty_like(hidden_states) - return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output) + return torch.sum(moe_output.view(-1, topk, K), dim=1, out=output) -class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute): +def batched_fused_marlin_moe( + hidden_states: torch.Tensor, + expert_num_tokens: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + bias1: torch.Tensor | None, + bias2: torch.Tensor | None, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + gating_output: torch.Tensor | None, + quant_type_id: int, + apply_router_weight_on_input: bool = False, + global_num_experts: int = -1, + activation: str | None = "silu", + expert_map: torch.Tensor | None = None, + global_scale1: torch.Tensor | None = None, + global_scale2: torch.Tensor | None = None, + g_idx1: torch.Tensor | None = None, + g_idx2: torch.Tensor | None = None, + sort_indices1: torch.Tensor | None = None, + sort_indices2: torch.Tensor | None = None, + w1_zeros: torch.Tensor | None = None, + w2_zeros: torch.Tensor | None = None, + workspace: torch.Tensor | None = None, + intermediate_cache13: torch.Tensor | None = None, + intermediate_cache2: torch.Tensor | None = None, + is_k_full: bool = True, + output: torch.Tensor | None = None, + inplace: bool = False, +) -> torch.Tensor: + """ + This function massages the inputs so the batched hidden_states can be + presented as a 2D contiguous tensor that could be used with + _fused_marlin_moe. + + Note that both batched_fused_marlin_moe and fused_marlin_moe ultimately + use `ops.moe_wna16_marlin_gemm` for the gemm operation and + `ops.moe_mna16_marlin_gemm` supports only 2D contiguous hidden_states. + Note that the moe_align_block_size function indicates, + - What rows of the A matrix (hidden_states) to access during the + matmul, via sorted_ids output. + - What expert_id to use for each block matmul, via expert_ids ouptut. + + In the batched version, the tokens are already grouped/batched by experts + they subscribe to. 
Due to this, we can represent the batched hidden_states + tensor of shape [B, MAX_TOKENS_PER_BATCH, K] as a 2D tensor of shape + [B * MAX_TOKENS_PER_BATCH, K]. We may treat this as a 2D contiguous tensor + with topk=1, since each token (row in the tensor) subscribes to exactly one + expert_id (which is the batch_id). Using the expert_num_tokens tensor, which + indicates how many tokens are actually valid in each batch, the + batched_moe_align_block_size function constructs the sorted_ids and + expert_ids tensors, so only relevant/valid rows of A (hidden_states) + are accessed and processed with the correct expert_ids. + """ + + assert hidden_states.ndim == 3, ( + f"hidden states must be batched, e.g. [B, MAX_TOKENS, K]. " + f"But got {hidden_states.size()}" + ) + if inplace: + assert output is None, "Conflicting request." + + quant_type = ScalarType.from_id(quant_type_id) + assert quant_type in [ + scalar_types.uint4, + scalar_types.uint8b128, + scalar_types.uint4b8, + scalar_types.float8_e4m3fn, + scalar_types.float4_e2m1f, + ] + + bit4_scalar_types = [ + scalar_types.uint4, + scalar_types.uint4b8, + scalar_types.float4_e2m1f, + ] + num_bits = 4 if quant_type in bit4_scalar_types else 8 + + B, BATCH_TOKENS_MAX, K = hidden_states.size() + M = hidden_states.view(-1, K).size(0) + E = w1.size(0) + + # Check constraints. + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert hidden_states.dtype in [torch.float16, torch.bfloat16] + assert expert_num_tokens.size(0) == E + assert B == E, ( + "Batch must be as big as the number of experts, as the tokens " + "are sorted into the batch/expert they belong to" + ) + assert w1.size(1) * 16 == K, "Hidden size mismatch w1" + assert w2.size(2) // (num_bits // 2) == K, "Hidden size mismatch w2" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert num_bits in [4, 8] + + # Technically, the tokens are already separated by their expert ids. + # Hidden-States can just be squeezed to have just 2 dimensions, + # [B * MAX_TOKENS, K] and top_k can be interpreted as just 1. + topk = 1 + + # TODO(varun) : Choose a decent block size like in fused_marlin_moe + block_size_m = 64 + + sorted_token_ids, expert_ids, num_tokens_post_padded = batched_moe_align_block_size( + max_tokens_per_batch=BATCH_TOKENS_MAX, + block_size=block_size_m, + expert_num_tokens=expert_num_tokens, + ) + + if output is None and inplace: + output = hidden_states + + # TODO (varun): This can be avoided by plumbing the marlin kernel to + # ignore topk_weights when topk_weights_ptr is a nullptr.
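+ # With topk=1, these all-ones weights scale each token by 1.0; they exist only to satisfy the kernel's topk_weights argument.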
+ topk_weights = torch.ones( + (M, topk), device=hidden_states.device, dtype=torch.float32 + ) + + assert activation is not None + output = _fused_marlin_moe( + hidden_states=hidden_states.view(-1, K), + w1=w1, + w2=w2, + bias1=bias1, + bias2=bias2, + w1_scale=w1_scale, + w2_scale=w2_scale, + topk_weights=topk_weights, + num_topk=topk, + quant_type=quant_type, + apply_router_weight_on_input=apply_router_weight_on_input, + activation=activation, + expert_map=expert_map, + block_size_m=block_size_m, + sorted_token_ids=sorted_token_ids, + expert_ids=expert_ids, + num_tokens_post_padded=num_tokens_post_padded, + global_scale1=global_scale1, + global_scale2=global_scale2, + g_idx1=g_idx1, + g_idx2=g_idx2, + sort_indices1=sort_indices1, + sort_indices2=sort_indices2, + w1_zeros=w1_zeros, + w2_zeros=w2_zeros, + workspace=workspace, + intermediate_cache13=intermediate_cache13, + intermediate_cache2=intermediate_cache2, + output=output.view(-1, K) if output is not None else output, + is_k_full=is_k_full, + ) + + output = output.view(B, BATCH_TOKENS_MAX, K) + + return output + + +class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute): def __init__(self, quant_config: FusedMoEQuantConfig): # TODO (varun) : Enable activation quantization assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16" super().__init__(quant_config) - @override def moe_problem_size( self, a1: torch.Tensor, @@ -274,6 +511,11 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute): return E, M, N, K, topk + +class MarlinExperts(MarlinExpertsBase): + def __init__(self, quant_config: FusedMoEQuantConfig): + super().__init__(quant_config) + def supports_expert_map(self) -> bool: return True @@ -365,3 +607,90 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute): intermediate_cache13=workspace2, intermediate_cache2=workspace13, ) + + +class BatchedMarlinExperts(MarlinExpertsBase): + def __init__( + self, + max_num_tokens: int, + num_dispatchers: int, + quant_config: FusedMoEQuantConfig, + ): + super().__init__(quant_config) + self.max_num_tokens = max_num_tokens + self.num_dispatchers = num_dispatchers + + def supports_expert_map(self) -> bool: + return True + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + return TopKWeightAndReduceDelegate() + + @property + def activation_formats( + self, + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return ( + mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts, + ) + + def supports_chunking(self) -> bool: + return False + + def workspace_shapes( + self, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + num_dispatchers = self.num_dispatchers + num_experts = local_num_experts + max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens + workspace13 = (num_experts * max_num_tokens * num_dispatchers, max(K, N * 2)) + workspace2 = (num_experts * max_num_tokens * num_dispatchers, N) + output = (num_experts, max_num_tokens * num_dispatchers, K) + return (workspace13, workspace2, output) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, + 
workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + apply_router_weight_on_input: bool, + ): + assert expert_tokens_meta is not None, "Num valid tokens per batch is required" + return batched_fused_marlin_moe( + hidden_states=hidden_states, + expert_num_tokens=expert_tokens_meta.expert_num_tokens, + w1=w1, + w2=w2, + bias1=self.w1_bias, + bias2=self.w2_bias, + w1_scale=self.w1_scale, + w2_scale=self.w2_scale, + gating_output=None, + quant_type_id=scalar_types.float4_e2m1f.id, # works only for w4a16 + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + activation=activation, + expert_map=expert_map, + output=output, + intermediate_cache13=workspace13, + intermediate_cache2=workspace2, + ) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index de4ed58e0cf4b..3bb544a49f3a9 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -994,6 +994,11 @@ def maybe_roundup_hidden_size( hidden_size, act_dtype ) + if moe_parallel_config.use_deepep_ll_kernels: + hidden_size = DeepEPLLPrepareAndFinalize.maybe_roundup_layer_hidden_size( + hidden_size + ) + # we are padding globally so EP buffer allocation works if quant_config and quant_config.get_name() == "mxfp4": from vllm.model_executor.layers.quantization.mxfp4 import ( diff --git a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py index a0d14bdf607e7..f4d8a86c058a8 100644 --- a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py +++ b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py @@ -83,3 +83,92 @@ def moe_align_block_size( expert_ids = expert_map[expert_ids] return sorted_ids, expert_ids, num_tokens_post_pad + + +def batched_moe_align_block_size( + max_tokens_per_batch: int, block_size: int, expert_num_tokens: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Given num_batches, max_tokens_per_batch, block_size and the number of + valid-tokens in each batch, prepare sorted_token_ids, expert_ids and + num_tokens_post_pad. sorted_token_ids, expert_ids and num_tokens_post_pad + have the same semantics as in moe_align_block_size. + + This function is intended to be a drop in replacement for + moe_align_batch_size for the batched case. + + Parameters: + - max_tokens_per_batch (int): Number of tokens in each batch (both + valid and invalid). + - block_size (int): block_size to align the data to. + - expert_num_tokens (torch.Tensor): expert_num_tokens[i], indicates + the number of valid tokens in batch i. + + Returns: + - sorted_token_ids (torch.Tensor): Torch tensor of size + (num_batches * max_tokens_per_batch) indicating the token indices for + that block. + - expert_ids (torch.Tensor): Torch tensor of size + ceil((num_batches * max_tokens_per_batch) / block_size) indicating + what expert to use for each block. + - num_tokens_post_pad (torch.Tensor): Torch tensor of size 1 + indicating the number of valid blocks with actual data to + process. This is represented in terms of num tokens. + Example: + Let num_batches=5, max_tokens_per_batch=8, block_size=4, and + expert_num_tokens=[2, 3, 0, 6, 8]. This expert_num_tokens tensor + indicates that, + - The first 2 tokens in the 0th batch are valid and the rest 6 are + invalid (i.e. 
in the 2D hidden_states tensor of shape + [num_batches * max_tokens_per_batch, K], indices 0, 1 are valid) + - The first 3 tokens in the 1st batch are valid, i.e. indices 8, 9, 10 + - 0 tokens in the 2nd batch are valid + - The first 6 tokens in the 3rd batch are valid, i.e. indices + 24, 25, 26, 27, 28, 29 + - and so on ... + + In this case, + sorted_token_ids will be [0, 1, 40, 40, + 8, 9, 10, 40, + 24, 25, 26, 27, + 28, 29, 40, 40, + 32, 33, 34, 35, + 36, 37, 38, 39, + 40, 40, 40, 40, + (rest all 40, 40, 40, 40) + ...] + Here, 40 represents an invalid index, as there is no token index 40. + The gemm kernel using this sorted_token_ids is expected to skip the + gemm computation when it encounters this invalid index. + + expert_ids will be [0, 1, 3, 3, 4, 4, -1, -1, (rest all -1) ...] + Here, -1 represents an invalid expert. The gemm kernel using this + expert_ids is expected to skip the gemm computation when it encounters + an expert of id -1. + + num_tokens_post_pad will be 24, as sorted_token_ids has valid entries + up to index 24. + """ + + B = expert_num_tokens.size(0) + device = expert_num_tokens.device + + # Round up so each batch can be split to blocks evenly. + max_num_tokens_padded = B * round_up(max_tokens_per_batch, block_size) + + sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device=device) + assert max_num_tokens_padded % block_size == 0 + max_num_m_blocks = max_num_tokens_padded // block_size + expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device=device) + num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=device) + + ops.batched_moe_align_block_size( + max_tokens_per_batch, + block_size, + expert_num_tokens, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + + return sorted_ids, expert_ids, num_tokens_post_pad diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index a7f9fdcb5513e..2eda2abfb40bd 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.config import ( ocp_mx_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + BatchedMarlinExperts, MarlinExperts, fused_marlin_moe, ) @@ -797,9 +798,19 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts ): - raise NotImplementedError( - "Mxfp4 does not support batched experts format for EP" - ) + if self.mxfp4_backend == Mxfp4Backend.MARLIN: + max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() + assert max_num_tokens_per_rank is not None + assert self.moe_quant_config is not None + return BatchedMarlinExperts( + max_num_tokens=max_num_tokens_per_rank, + num_dispatchers=prepare_finalize.num_dispatchers(), + quant_config=self.moe_quant_config, + ) + else: + raise NotImplementedError( + "Incompatible Mxfp4 backend for EP batched experts format" + ) else: assert self.moe_quant_config is not None if (