diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index 74ed34d0a474b..2b195b3dd15d0 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -93,6 +93,8 @@ To be used with a particular `FusedMoEPrepareAndFinalize` sub-class, MoE kernels | gpt oss triton | standard | N/A | N/A | 5 | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],
[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] | | deep gemm+triton2 | standard,
batched | all1 | G(128),A,T | silu, gelu | 6 | Y | [`TritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe.TritonOrDeepGemmExperts],
[`BatchedTritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe.BatchedTritonOrDeepGemmExperts] | | marlin | standard | 3 | 3 | silu,
swigluoai | Y | N | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe] | + +| marlin experts | standard | N/A | N/A | silu,
swigluoai | Y | Y | [`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts] | | trtllm | standard | mxfp4,
nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] | | pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] | | iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] | @@ -114,6 +116,6 @@ The following table shows "families" of modular kernels that are intended to wor | backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses | |----------------------------------|------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------| -| deepep_high_throughput,
pplx | `DeepEPHTPrepareAndFinalize`,
`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,
`BatchedTritonExperts`,
`BatchedTritonOrDeepGemmExperts`,
`CutlassBatchedExpertsFp8` | -| deepep_low_latency | `DeepEPLLPrepareAndFinalize` | `DeepGemmExperts`,
`TritonExperts`,
`TritonOrDeepGemmExperts`,
`CutlassExpertsFp8` | +| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,
`TritonExperts`,
`TritonOrDeepGemmExperts`,
`CutlassExpertsFp8`,
`MarlinExperts` | +| deepep_low_latency,
pplx | `DeepEPLLPrepareAndFinalize`,
`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,
`BatchedTritonExperts`,
`BatchedTritonOrDeepGemmExperts`,
`CutlassBatchedExpertsFp8`| | flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` | diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index cf0b965cc8c51..2017a01475b29 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -303,7 +303,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): assert w2.size(1) == K - E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size( + E, max_num_tokens, N, K, top_k_num = self.moe_problem_size( hidden_states, w1, w2, topk_ids) workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N)) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 8c2ff580575f5..1578e4822765d 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -712,7 +712,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, ): - e, m, n, k, _ = mk._moe_problem_size(hidden_states, w1, w2, topk_ids) + e, m, n, k, _ = self.moe_problem_size(hidden_states, w1, w2, topk_ids) n = w2.shape[2] * 2 run_cutlass_moe_fp4( diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index e49750bc92b3b..fee628eae4d84 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -906,7 +906,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): expert_num_tokens = expert_tokens_meta.expert_num_tokens - E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size( + E, max_num_tokens, N, K, top_k_num = self.moe_problem_size( hidden_states, w1, w2, topk_ids) assert w1.size(0) == E diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index eb12a9b0a233f..617d871a5b3d5 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -4,11 +4,18 @@ from typing import Optional import torch +from typing_extensions import override import vllm._custom_ops as ops +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceNoOP) +from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - marlin_make_workspace_new, maybe_warn_marlin_atomic_add) + marlin_make_workspace_new, marlin_moe_intermediate_size, + maybe_warn_marlin_atomic_add) from vllm.scalar_type import ScalarType, scalar_types from vllm.utils import direct_register_custom_op @@ -20,7 +27,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, bias2: Optional[torch.Tensor], w1_scale: torch.Tensor, w2_scale: torch.Tensor, - gating_output: torch.Tensor, + gating_output: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, quant_type_id: int, @@ -37,7 +44,10 @@ def fused_marlin_moe(hidden_states: torch.Tensor, w1_zeros: Optional[torch.Tensor] 
= None, w2_zeros: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, + intermediate_cache13: Optional[torch.Tensor] = None, + intermediate_cache2: Optional[torch.Tensor] = None, is_k_full: bool = True, + output: Optional[torch.Tensor] = None, inplace: bool = False) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -49,8 +59,8 @@ def fused_marlin_moe(hidden_states: torch.Tensor, - w2 (torch.Tensor): The second set of expert weights. - w1_scale (torch.Tensor): Scale to be used for w1. - w2_scale (torch.Tensor): Scale to be used for w2. - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). + - gating_output (Optional[torch.Tensor]): The output of the gating + operation (before softmax). - g_idx1 (Optional[torch.Tensor]): The first set of act_order indices. - g_idx2 (Optional[torch.Tensor]): The second set of act_order indices. - sort_indices1 (Optional[torch.Tensor]): The first act_order input @@ -78,8 +88,9 @@ def fused_marlin_moe(hidden_states: torch.Tensor, num_bits = 4 if quant_type in bit4_scalar_types else 8 # Check constraints. - assert hidden_states.shape[0] == gating_output.shape[ - 0], "Number of tokens mismatch" + if gating_output is not None: + assert hidden_states.shape[0] == gating_output.shape[ + 0], "Number of tokens mismatch" assert hidden_states.shape[ 1] == w1.shape[1] * 16, "Hidden size mismatch w1" assert hidden_states.shape[1] == w2.shape[2] // ( @@ -93,7 +104,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor, M, K = hidden_states.shape E = w1.shape[0] - N = w2.shape[1] * 16 + N = marlin_moe_intermediate_size(w1, w2) topk = topk_ids.shape[1] # M block size selection logic @@ -111,20 +122,24 @@ def fused_marlin_moe(hidden_states: torch.Tensor, if workspace is None: workspace = marlin_make_workspace_new(hidden_states.device, 4) - intermediate_cache2 = torch.empty( - (M * topk_ids.shape[1], N), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - intermediate_cache13 = torch.empty( - (M * topk_ids.shape[1] * max(2 * N, K), ), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - intermediate_cache1 = intermediate_cache13[:M * topk_ids.shape[1] * 2 * N] - intermediate_cache1 = intermediate_cache1.view(-1, 2 * N) - intermediate_cache3 = intermediate_cache13[:M * topk_ids.shape[1] * K] - intermediate_cache3 = intermediate_cache3.view(-1, K) + if intermediate_cache2 is None: + intermediate_cache2 = torch.empty( + (M * topk, N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + if intermediate_cache13 is None: + intermediate_cache13 = torch.empty( + (M * topk * max(2 * N, K), ), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + intermediate_cache1 = _resize_cache(intermediate_cache13, + (M * topk, 2 * N)) + intermediate_cache3 = _resize_cache(intermediate_cache13, (M * topk, K)) + intermediate_cache2 = _resize_cache(intermediate_cache2, (M * topk, N)) maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype) use_atomic_add = hidden_states.dtype == torch.half or \ @@ -200,10 +215,9 @@ def fused_marlin_moe(hidden_states: torch.Tensor, use_fp32_reduce=True, is_zp_float=False).view(-1, topk, K) - output = hidden_states if inplace else torch.empty_like(hidden_states) - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), - dim=1, - out=output) + if output is None: + output = hidden_states if inplace else torch.empty_like(hidden_states) + return torch.sum(intermediate_cache3.view(-1, 
topk, K), dim=1, out=output) def fused_marlin_moe_fake(hidden_states: torch.Tensor, @@ -211,7 +225,7 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor, w2: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, - gating_output: torch.Tensor, + gating_output: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, quant_type_id: int, @@ -227,7 +241,10 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor, w1_zeros: Optional[torch.Tensor] = None, w2_zeros: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, + intermediate_cache13: Optional[torch.Tensor] = None, + intermediate_cache2: Optional[torch.Tensor] = None, is_k_full: bool = True, + output: Optional[torch.Tensor] = None, inplace: bool = False) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -237,3 +254,124 @@ direct_register_custom_op( op_func=fused_marlin_moe, fake_impl=fused_marlin_moe_fake, ) + + +class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__(self, quant_config: FusedMoEQuantConfig): + # TODO (varun) : Enable activation quantization + assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16" + super().__init__(quant_config) + + @override + def moe_problem_size( + self, + a1: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + ) -> tuple[int, int, int, int, int]: + assert w1.dim() == 3 and w2.dim() == 3 + + E = w1.size(0) + K = a1.size(-1) + N = marlin_moe_intermediate_size(w1, w2) + + if a1.dim() == 2: + # Make sure we are using the correct a1 (pre-permute). + assert topk_ids.size(0) == a1.size(0), \ + f"{topk_ids.size(0)} != {a1.size(0)}" + M = a1.size(0) + else: + assert a1.dim() == 3 + assert a1.size(0) == E, f"{a1.size(0)} == {E}" + M = a1.size(1) # This is max_num_tokens + + assert topk_ids.dim() == 2 + topk = topk_ids.size(1) + + return E, M, N, K, topk + + def supports_expert_map(self) -> bool: + return True + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + return TopKWeightAndReduceNoOP() + + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return (mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard) + + def supports_chunking(self) -> bool: + return True + + def workspace_shapes( + self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int, + topk: int, global_num_experts: int, local_num_experts: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata] + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + # Modular Kernel provisions the output buffer from workspace1. However in + # the fused_marlin_moe() function, the final torch.sum() is defined + # essentially as + # `torch.sum(workspace1, dim=1, out=output)` + # Having overlapping input and output tensors for torch.sum seems + # error-prone and depends on how torch.sum is implemented. + # For this reason we swap the workspaces so that the torch.sum() input + # (intermediate_cache3, carved from intermediate_cache13) is provisioned + # from workspace2 and never aliases the output.
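+ # Note: workspace1 is sized (M * topk, max(N, K)) below so that the same + # buffer is large enough to back intermediate_cache2 ((M * topk, N)) during + # the expert computation and the (M, K) output that is carved from it + # afterwards.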
+ + # Workspace/IntermediateCache allocation matching fused_marlin_moe() + #workspace1 = (M * topk * max(2 * N, K),) + #workspace2 = (M * topk, N) + + # Workspace/IntermediateCache allocation accounting for output buffer + # provisioning + workspace1 = (M * topk, max(N, K)) + workspace2 = (M * topk * max(2 * N, K), ) + output = (M, K) + + return (workspace1, workspace2, output, a.dtype) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + ): + assert self.w1_scale is not None + assert self.w2_scale is not None + return fused_marlin_moe( + hidden_states=hidden_states, + w1=w1, + w2=w2, + bias1=self.w1_bias, + bias2=self.w2_bias, + w1_scale=self.w1_scale, + w2_scale=self.w2_scale, + gating_output=None, + topk_weights=topk_weights, + topk_ids=topk_ids, + quant_type_id=scalar_types.float4_e2m1f.id, # works only for w4a16 + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + activation=activation, + expert_map=expert_map, + output=output, + # Workspaces are swapped in workspace_shapes() to account for proper + # output buffer allocation. Please refer to workspace_shapes(). + intermediate_cache13=workspace2, + intermediate_cache2=workspace13) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 49f278c72007f..f96525734fd9a 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1780,7 +1780,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn ] - E, num_tokens, N, K, top_k_num = mk._moe_problem_size( + E, num_tokens, N, K, top_k_num = self.moe_problem_size( hidden_states, w1, w2, topk_ids) if global_num_experts == -1: diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index b6afc8651e36d..a7617f8b7297d 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -55,46 +55,6 @@ from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled, # -def _moe_problem_size( - a1: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_ids: torch.Tensor, -) -> tuple[int, int, int, int, int]: - """ - Extract the MoE problem size from the given tensor arguments: - - a: The hidden states, input to the MoE layer. - - w1: The first set of expert weights. - - w2: The second set of expert weights. - - topk_ids: The topk ids. - - Note: extracting the problem shape from the weight and activation tensors is - not obvious. It needs to be done this way specifically due to subtle issues - with particular kernels, e.g. the int4 kernels divide the trailing dimension - by two, so it's not "correct" to extract N or K from the trailing dimension - of w1 or w2. Similarly, some kernels transpose the weights, so this needs - to be kept in mind. 
- """ - assert w1.dim() == 3 and w2.dim() == 3 - E, N, _ = w1.size() - K = a1.size(-1) - - if a1.dim() == 2: - # Make sure we are using the correct a1 (pre-permute). - assert topk_ids.size(0) == a1.size(0), \ - f"{topk_ids.size(0)} != {a1.size(0)}" - M = a1.size(0) - else: - assert a1.dim() == 3 - assert a1.size(0) == E, f"{a1.size(0)} == {E}" - M = a1.size(1) # This is max_num_tokens - - assert topk_ids.dim() == 2 - topk = topk_ids.size(1) - - return E, M, N, K, topk - - class FusedMoEActivationFormat(Enum): """ The standard activation format (num_tokens, hidden dim). @@ -391,6 +351,50 @@ class FusedMoEPermuteExpertsUnpermute(ABC): """ raise NotImplementedError + def moe_problem_size( + self, + a1: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + ) -> tuple[int, int, int, int, int]: + """ + Extract the MoE problem size from the given tensor arguments: + - a: The hidden states, input to the MoE layer. + - w1: The first set of expert weights. + - w2: The second set of expert weights. + - topk_ids: The topk ids. + + Note: extracting the problem shape from the weight and activation + tensors is not obvious. It needs to be done this way specifically + due to subtle issues with particular kernels, e.g. the int4 kernels + divide the trailing dimension by two, so it's not "correct" to + extract N or K from the trailing dimension of w1 or w2. Similarly, + some kernels transpose the weights, so this needs to be kept in mind. + + Note: This implementation covers most cases. However, if experts + require a specialized implementation, like MarlinExperts, they are free + to override this function. + """ + assert w1.dim() == 3 and w2.dim() == 3 + E, N, _ = w1.size() + K = a1.size(-1) + + if a1.dim() == 2: + # Make sure we are using the correct a1 (pre-permute). + assert topk_ids.size(0) == a1.size(0), \ + f"{topk_ids.size(0)} != {a1.size(0)}" + M = a1.size(0) + else: + assert a1.dim() == 3 + assert a1.size(0) == E, f"{a1.size(0)} == {E}" + M = a1.size(1) # This is max_num_tokens + + assert topk_ids.dim() == 2 + topk = topk_ids.size(1) + + return E, M, N, K, topk + # # Various helpers for accessing quantization parameters from the # quant_config. 
@@ -674,7 +678,8 @@ class FusedMoEModularKernel(torch.nn.Module): apply_router_weight_on_input: bool, ) -> torch.Tensor: - _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) + _, M, N, K, top_k = self.fused_experts.moe_problem_size( + a1q, w1, w2, topk_ids) (workspace13_shape, workspace2_shape, fused_out_shape, workspace_dtype) = self.fused_experts.workspace_shapes( @@ -737,7 +742,8 @@ class FusedMoEModularKernel(torch.nn.Module): apply_router_weight_on_input: bool, ) -> torch.Tensor: - _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) + _, M, N, K, top_k = self.fused_experts.moe_problem_size( + a1q, w1, w2, topk_ids) CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE num_chunks = cdiv(M, CHUNK_SIZE) diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 54194b2e7d5b0..950bf33dbf01d 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe import modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config, mxfp4_w4a16_moe_quant_config) +from vllm.model_executor.layers.fused_moe.fused_marlin_moe import MarlinExperts from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( OAITritonExperts) from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts @@ -92,7 +93,7 @@ def get_mxfp4_backend(): "Please `pip install vllm[flashinfer]` for best results.") # If FlashInfer is not available, try either Marlin or Triton - if current_platform.get_device_capability( + if envs.VLLM_MXFP4_USE_MARLIN or current_platform.get_device_capability( )[0] < 9 or not has_triton_kernels() or not is_torch_equal_or_newer( "2.8.0"): logger.info_once("Using Marlin backend") @@ -646,9 +647,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: if self.mxfp4_backend == Mxfp4Backend.MARLIN: - return None - - if self.mxfp4_backend == Mxfp4Backend.TRITON: + return mxfp4_w4a16_moe_quant_config( + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + ) + elif self.mxfp4_backend == Mxfp4Backend.TRITON: w1_scale = self.w13_precision_config w2_scale = self.w2_precision_config return mxfp4_w4a16_moe_quant_config( @@ -690,6 +695,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): } return TrtLlmGenExperts(self.moe, self.moe_quant_config, **kwargs) + elif (self.mxfp4_backend == Mxfp4Backend.MARLIN): + return MarlinExperts(self.moe_quant_config) else: return OAITritonExperts(self.moe_quant_config) @@ -782,6 +789,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): if enable_eplb: raise NotImplementedError("EPLB is not supported for mxfp4") + if self.fused_experts is not None: + return self._route_and_experts( + layer, + x, + router_logits, + top_k, + renormalize, + use_grouped_topk, + topk_group, + num_expert_group, + global_num_experts, + expert_map, + custom_routing_function, + scoring_func, + e_score_correction_bias, + apply_router_weight_on_input, + activation, + enable_eplb, + expert_load_view, + logical_to_physical_map, + logical_replica_count, + ) + if self.mxfp4_backend == Mxfp4Backend.MARLIN: topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, @@ -815,29 +845,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): activation=activation, expert_map=expert_map) - if self.fused_experts is not None: - return 
self._route_and_experts( - layer, - x, - router_logits, - top_k, - renormalize, - use_grouped_topk, - topk_group, - num_expert_group, - global_num_experts, - expert_map, - custom_routing_function, - scoring_func, - e_score_correction_bias, - apply_router_weight_on_input, - activation, - enable_eplb, - expert_load_view, - logical_to_physical_map, - logical_replica_count, - ) - assert _can_support_mxfp4( use_grouped_topk, topk_group, num_expert_group, expert_map, custom_routing_function, e_score_correction_bias, diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 317ad079b392d..6c7604cc9d048 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -187,6 +187,16 @@ def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \ supports_router_weight and supports_activation +def marlin_moe_intermediate_size(w1_packed: torch.Tensor, + w2_packed: torch.Tensor): + """ + Given Marlin packed weight matrices w1_packed, and w2_packed, + return the MoE intermediate size N + """ + marlin_tile_size = 16 + return w2_packed.size(1) * marlin_tile_size + + def marlin_make_workspace(output_size_per_partition: int, device: torch.device) -> torch.Tensor: max_workspace_size = (output_size_per_partition //
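As additional context for the `moe_problem_size` hook in this patch: the base-class implementation reads the problem shape directly from `w1` and `a1`, which is exactly what breaks for Marlin, whose weights are stored in packed 16-row tiles, so `MarlinExperts` overrides it and recovers the intermediate size via `marlin_moe_intermediate_size` (i.e. `w2_packed.size(1) * 16`). The following is a minimal, self-contained sketch of that shape bookkeeping only; the concrete dimensions and the placeholder trailing dims of the packed weights are illustrative assumptions, not values taken from the patch.

```python
import torch

MARLIN_TILE = 16  # Marlin packs the reduction dimension in 16-row tiles

# Illustrative MoE problem size (assumed for this sketch).
E, M, N, K, topk = 8, 4, 2048, 1024, 2

a1 = torch.randn(M, K)                              # pre-permute activations
topk_ids = torch.zeros(M, topk, dtype=torch.int64)  # router selections
w1_packed = torch.empty(E, K // MARLIN_TILE, 8)     # trailing dim is a placeholder
w2_packed = torch.empty(E, N // MARLIN_TILE, 8)     # trailing dim is a placeholder


def marlin_moe_intermediate_size(w1_packed: torch.Tensor,
                                 w2_packed: torch.Tensor) -> int:
    # Same rule as the helper added in marlin_utils.py: N is recovered from
    # the packed w2, whose dim 1 counts 16-row Marlin tiles.
    return w2_packed.size(1) * MARLIN_TILE


# Mirrors MarlinExperts.moe_problem_size(): E comes from the expert dim of
# w1, K and M from the activations, N from the packed w2, topk from the ids.
E_out = w1_packed.size(0)
K_out = a1.size(-1)
N_out = marlin_moe_intermediate_size(w1_packed, w2_packed)
M_out = a1.size(0) if a1.dim() == 2 else a1.size(1)  # 3-D => max_num_tokens
topk_out = topk_ids.size(1)

assert (E_out, M_out, N_out, K_out, topk_out) == (E, M, N, K, topk)
```

The real override additionally enforces the pre-permute invariants shown in the diff (`topk_ids.size(0) == a1.size(0)` for 2-D activations, `a1.size(0) == E` for 3-D batched activations) before returning the tuple.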