diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md
index 74ed34d0a474b..2b195b3dd15d0 100644
--- a/docs/design/moe_kernel_features.md
+++ b/docs/design/moe_kernel_features.md
@@ -93,6 +93,8 @@ To be used with a particular `FusedMoEPrepareAndFinalize` sub-class, MoE kernels
| gpt oss triton | standard | N/A | N/A | 5 | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
| deep gemm+triton2 | standard,batched | all1 | G(128),A,T | silu, gelu | 6 | Y | [`TritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe.TritonOrDeepGemmExperts],[`BatchedTritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe.BatchedTritonOrDeepGemmExperts] |
| marlin | standard | 3 | 3 | silu,swigluoai | Y | N | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe] |
+| marlin experts | standard | N/A | N/A | silu,swigluoai | Y | Y | [`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts] |
| trtllm | standard | mxfp4,nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
| pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
| iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] |
@@ -114,6 +116,6 @@ The following table shows "families" of modular kernels that are intended to wor
| backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses |
|----------------------------------|------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------|
-| deepep_high_throughput,pplx | `DeepEPHTPrepareAndFinalize`,`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,`BatchedTritonExperts`,`BatchedTritonOrDeepGemmExperts`,`CutlassBatchedExpertsFp8` |
-| deepep_low_latency | `DeepEPLLPrepareAndFinalize` | `DeepGemmExperts`,`TritonExperts`,`TritonOrDeepGemmExperts`,`CutlassExpertsFp8` |
+| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,`TritonExperts`,`TritonOrDeepGemmExperts`,`CutlassExpertsFp8`, `MarlinExperts` |
+| deepep_low_latency,pplx | `DeepEPLLPrepareAndFinalize`,`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,`BatchedTritonExperts`,`BatchedTritonOrDeepGemmExperts`,`CutlassBatchedExpertsFp8`|
| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |
diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
index cf0b965cc8c51..2017a01475b29 100644
--- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
@@ -303,7 +303,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert w2.size(1) == K
- E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
+ E, max_num_tokens, N, K, top_k_num = self.moe_problem_size(
hidden_states, w1, w2, topk_ids)
workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
index 8c2ff580575f5..1578e4822765d 100644
--- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -712,7 +712,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
- e, m, n, k, _ = mk._moe_problem_size(hidden_states, w1, w2, topk_ids)
+ e, m, n, k, _ = self.moe_problem_size(hidden_states, w1, w2, topk_ids)
n = w2.shape[2] * 2
run_cutlass_moe_fp4(
diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
index e49750bc92b3b..fee628eae4d84 100644
--- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
@@ -906,7 +906,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_num_tokens = expert_tokens_meta.expert_num_tokens
- E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
+ E, max_num_tokens, N, K, top_k_num = self.moe_problem_size(
hidden_states, w1, w2, topk_ids)
assert w1.size(0) == E
diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
index eb12a9b0a233f..617d871a5b3d5 100644
--- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -4,11 +4,18 @@
from typing import Optional
import torch
+from typing_extensions import override
import vllm._custom_ops as ops
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+ TopKWeightAndReduceNoOP)
+from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
- marlin_make_workspace_new, maybe_warn_marlin_atomic_add)
+ marlin_make_workspace_new, marlin_moe_intermediate_size,
+ maybe_warn_marlin_atomic_add)
from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils import direct_register_custom_op
@@ -20,7 +27,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
bias2: Optional[torch.Tensor],
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
- gating_output: torch.Tensor,
+ gating_output: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_type_id: int,
@@ -37,7 +44,10 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
workspace: Optional[torch.Tensor] = None,
+ intermediate_cache13: Optional[torch.Tensor] = None,
+ intermediate_cache2: Optional[torch.Tensor] = None,
is_k_full: bool = True,
+ output: Optional[torch.Tensor] = None,
inplace: bool = False) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -49,8 +59,8 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
- w2 (torch.Tensor): The second set of expert weights.
- w1_scale (torch.Tensor): Scale to be used for w1.
- w2_scale (torch.Tensor): Scale to be used for w2.
- - gating_output (torch.Tensor): The output of the gating operation
- (before softmax).
+ - gating_output (Optional[torch.Tensor]): The output of the gating
+ operation (before softmax).
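+      May be None when topk_weights and topk_ids are computed by the
+      caller (e.g. the modular MarlinExperts path passes gating_output=None).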
- g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
- g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
- sort_indices1 (Optional[torch.Tensor]): The first act_order input
@@ -78,8 +88,9 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
num_bits = 4 if quant_type in bit4_scalar_types else 8
# Check constraints.
- assert hidden_states.shape[0] == gating_output.shape[
- 0], "Number of tokens mismatch"
+ if gating_output is not None:
+ assert hidden_states.shape[0] == gating_output.shape[
+ 0], "Number of tokens mismatch"
assert hidden_states.shape[
1] == w1.shape[1] * 16, "Hidden size mismatch w1"
assert hidden_states.shape[1] == w2.shape[2] // (
@@ -93,7 +104,7 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
M, K = hidden_states.shape
E = w1.shape[0]
- N = w2.shape[1] * 16
+ N = marlin_moe_intermediate_size(w1, w2)
topk = topk_ids.shape[1]
# M block size selection logic
@@ -111,20 +122,24 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
if workspace is None:
workspace = marlin_make_workspace_new(hidden_states.device, 4)
- intermediate_cache2 = torch.empty(
- (M * topk_ids.shape[1], N),
- device=hidden_states.device,
- dtype=hidden_states.dtype,
- )
- intermediate_cache13 = torch.empty(
- (M * topk_ids.shape[1] * max(2 * N, K), ),
- device=hidden_states.device,
- dtype=hidden_states.dtype,
- )
- intermediate_cache1 = intermediate_cache13[:M * topk_ids.shape[1] * 2 * N]
- intermediate_cache1 = intermediate_cache1.view(-1, 2 * N)
- intermediate_cache3 = intermediate_cache13[:M * topk_ids.shape[1] * K]
- intermediate_cache3 = intermediate_cache3.view(-1, K)
+ if intermediate_cache2 is None:
+ intermediate_cache2 = torch.empty(
+ (M * topk, N),
+ device=hidden_states.device,
+ dtype=hidden_states.dtype,
+ )
+
+ if intermediate_cache13 is None:
+ intermediate_cache13 = torch.empty(
+ (M * topk * max(2 * N, K), ),
+ device=hidden_states.device,
+ dtype=hidden_states.dtype,
+ )
+
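+    # intermediate_cache1 (first GEMM output, (M * topk, 2 * N)) and
+    # intermediate_cache3 (second GEMM output, (M * topk, K)) are views over
+    # the same intermediate_cache13 buffer; cache1 is consumed by the
+    # activation before cache3 is written, so reusing the storage is safe.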
+ intermediate_cache1 = _resize_cache(intermediate_cache13,
+ (M * topk, 2 * N))
+ intermediate_cache3 = _resize_cache(intermediate_cache13, (M * topk, K))
+ intermediate_cache2 = _resize_cache(intermediate_cache2, (M * topk, N))
maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype)
use_atomic_add = hidden_states.dtype == torch.half or \
@@ -200,10 +215,9 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
use_fp32_reduce=True,
is_zp_float=False).view(-1, topk, K)
- output = hidden_states if inplace else torch.empty_like(hidden_states)
- return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
- dim=1,
- out=output)
+ if output is None:
+ output = hidden_states if inplace else torch.empty_like(hidden_states)
+ return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output)
def fused_marlin_moe_fake(hidden_states: torch.Tensor,
@@ -211,7 +225,7 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
- gating_output: torch.Tensor,
+ gating_output: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
quant_type_id: int,
@@ -227,7 +241,10 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
workspace: Optional[torch.Tensor] = None,
+ intermediate_cache13: Optional[torch.Tensor] = None,
+ intermediate_cache2: Optional[torch.Tensor] = None,
is_k_full: bool = True,
+ output: Optional[torch.Tensor] = None,
inplace: bool = False) -> torch.Tensor:
return torch.empty_like(hidden_states)
@@ -237,3 +254,124 @@ direct_register_custom_op(
op_func=fused_marlin_moe,
fake_impl=fused_marlin_moe_fake,
)
+
+
+class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
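+    """
+    Modular-kernel experts implementation that wraps fused_marlin_moe() for
+    mxfp4 w4a16 MoE layers. Uses the standard activation format and supports
+    expert maps and chunking (see docs/design/moe_kernel_features.md).
+    """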
+
+ def __init__(self, quant_config: FusedMoEQuantConfig):
+ # TODO (varun) : Enable activation quantization
+ assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
+ super().__init__(quant_config)
+
+ @override
+ def moe_problem_size(
+ self,
+ a1: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_ids: torch.Tensor,
+ ) -> tuple[int, int, int, int, int]:
+ assert w1.dim() == 3 and w2.dim() == 3
+
+ E = w1.size(0)
+ K = a1.size(-1)
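+        # Marlin packs its weights in 16-wide tiles, so the base-class
+        # moe_problem_size() (which reads N from w1.size(1)) would return
+        # K // 16 here. Derive N from the packed w2 instead.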
+ N = marlin_moe_intermediate_size(w1, w2)
+
+ if a1.dim() == 2:
+ # Make sure we are using the correct a1 (pre-permute).
+ assert topk_ids.size(0) == a1.size(0), \
+ f"{topk_ids.size(0)} != {a1.size(0)}"
+ M = a1.size(0)
+ else:
+ assert a1.dim() == 3
+ assert a1.size(0) == E, f"{a1.size(0)} == {E}"
+ M = a1.size(1) # This is max_num_tokens
+
+ assert topk_ids.dim() == 2
+ topk = topk_ids.size(1)
+
+ return E, M, N, K, topk
+
+ def supports_expert_map(self) -> bool:
+ return True
+
+ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+ return TopKWeightAndReduceNoOP()
+
+ @property
+ def activation_formats(
+ self
+ ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+ return (mk.FusedMoEActivationFormat.Standard,
+ mk.FusedMoEActivationFormat.Standard)
+
+ def supports_chunking(self) -> bool:
+ return True
+
+ def workspace_shapes(
+ self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
+ topk: int, global_num_experts: int, local_num_experts: int,
+ expert_tokens_meta: Optional[mk.ExpertTokensMetadata]
+ ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
+        # The modular kernel provisions the output buffer from workspace1.
+        # However, fused_marlin_moe() computes its result essentially as
+        # `torch.sum(intermediate_cache13, dim=1, out=output)`, and
+        # intermediate_cache13 would otherwise also be carved from
+        # workspace1. Overlapping input and output tensors for torch.sum is
+        # error prone and depends on how torch.sum is implemented. For this
+        # reason we swap the two workspaces: intermediate_cache13 is
+        # provisioned from workspace2 so it does not alias the output buffer.
+
+        # Workspace/IntermediateCache allocation matching fused_marlin_moe():
+        # workspace1 = (M * topk * max(2 * N, K),)
+        # workspace2 = (M * topk, N)
+
+ # Workspace/IntermediateCache allocation accounting for output buffer
+ # provisioning
+ workspace1 = (M * topk, max(N, K))
+ workspace2 = (M * topk * max(2 * N, K), )
+ output = (M, K)
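+        # Hypothetical example: M=4, topk=2, N=1024, K=2048 gives
+        #   workspace1 = (8, 2048): reused as intermediate_cache2 (8, 1024)
+        #                           and as the (4, 2048) output buffer
+        #   workspace2 = (16384,) : reused as intermediate_cache1 (8, 2048)
+        #                           and intermediate_cache3 (8, 2048)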
+
+ return (workspace1, workspace2, output, a.dtype)
+
+ def apply(
+ self,
+ output: torch.Tensor,
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ activation: str,
+ global_num_experts: int,
+ expert_map: Optional[torch.Tensor],
+ a1q_scale: Optional[torch.Tensor],
+ a2_scale: Optional[torch.Tensor],
+ workspace13: torch.Tensor,
+ workspace2: torch.Tensor,
+ expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+ apply_router_weight_on_input: bool,
+ ):
+ assert self.w1_scale is not None
+ assert self.w2_scale is not None
+ return fused_marlin_moe(
+ hidden_states=hidden_states,
+ w1=w1,
+ w2=w2,
+ bias1=self.w1_bias,
+ bias2=self.w2_bias,
+ w1_scale=self.w1_scale,
+ w2_scale=self.w2_scale,
+ gating_output=None,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ quant_type_id=scalar_types.float4_e2m1f.id, # works only for w4a16
+ apply_router_weight_on_input=apply_router_weight_on_input,
+ global_num_experts=global_num_experts,
+ activation=activation,
+ expert_map=expert_map,
+ output=output,
+            # The workspaces are intentionally swapped here; see the comment
+            # in workspace_shapes() for the rationale.
+ intermediate_cache13=workspace2,
+ intermediate_cache2=workspace13)
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 49f278c72007f..f96525734fd9a 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1780,7 +1780,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
]
- E, num_tokens, N, K, top_k_num = mk._moe_problem_size(
+ E, num_tokens, N, K, top_k_num = self.moe_problem_size(
hidden_states, w1, w2, topk_ids)
if global_num_experts == -1:
diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py
index b6afc8651e36d..a7617f8b7297d 100644
--- a/vllm/model_executor/layers/fused_moe/modular_kernel.py
+++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -55,46 +55,6 @@ from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled,
#
-def _moe_problem_size(
- a1: torch.Tensor,
- w1: torch.Tensor,
- w2: torch.Tensor,
- topk_ids: torch.Tensor,
-) -> tuple[int, int, int, int, int]:
- """
- Extract the MoE problem size from the given tensor arguments:
- - a: The hidden states, input to the MoE layer.
- - w1: The first set of expert weights.
- - w2: The second set of expert weights.
- - topk_ids: The topk ids.
-
- Note: extracting the problem shape from the weight and activation tensors is
- not obvious. It needs to be done this way specifically due to subtle issues
- with particular kernels, e.g. the int4 kernels divide the trailing dimension
- by two, so it's not "correct" to extract N or K from the trailing dimension
- of w1 or w2. Similarly, some kernels transpose the weights, so this needs
- to be kept in mind.
- """
- assert w1.dim() == 3 and w2.dim() == 3
- E, N, _ = w1.size()
- K = a1.size(-1)
-
- if a1.dim() == 2:
- # Make sure we are using the correct a1 (pre-permute).
- assert topk_ids.size(0) == a1.size(0), \
- f"{topk_ids.size(0)} != {a1.size(0)}"
- M = a1.size(0)
- else:
- assert a1.dim() == 3
- assert a1.size(0) == E, f"{a1.size(0)} == {E}"
- M = a1.size(1) # This is max_num_tokens
-
- assert topk_ids.dim() == 2
- topk = topk_ids.size(1)
-
- return E, M, N, K, topk
-
-
class FusedMoEActivationFormat(Enum):
"""
The standard activation format (num_tokens, hidden dim).
@@ -391,6 +351,50 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
"""
raise NotImplementedError
+ def moe_problem_size(
+ self,
+ a1: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_ids: torch.Tensor,
+ ) -> tuple[int, int, int, int, int]:
+ """
+ Extract the MoE problem size from the given tensor arguments:
+ - a: The hidden states, input to the MoE layer.
+ - w1: The first set of expert weights.
+ - w2: The second set of expert weights.
+ - topk_ids: The topk ids.
+
+ Note: extracting the problem shape from the weight and activation
+ tensors is not obvious. It needs to be done this way specifically
+ due to subtle issues with particular kernels, e.g. the int4 kernels
+ divide the trailing dimension by two, so it's not "correct" to
+ extract N or K from the trailing dimension of w1 or w2. Similarly,
+ some kernels transpose the weights, so this needs to be kept in mind.
+
+ Note: This implementation covers most cases. However, if experts
+ require a specialized implementation, like MarlinExperts, they are free
+ to override this function.
+ """
+ assert w1.dim() == 3 and w2.dim() == 3
+ E, N, _ = w1.size()
+ K = a1.size(-1)
+
+ if a1.dim() == 2:
+ # Make sure we are using the correct a1 (pre-permute).
+ assert topk_ids.size(0) == a1.size(0), \
+ f"{topk_ids.size(0)} != {a1.size(0)}"
+ M = a1.size(0)
+ else:
+ assert a1.dim() == 3
+ assert a1.size(0) == E, f"{a1.size(0)} == {E}"
+ M = a1.size(1) # This is max_num_tokens
+
+ assert topk_ids.dim() == 2
+ topk = topk_ids.size(1)
+
+ return E, M, N, K, topk
+
#
# Various helpers for accessing quantization parameters from the
# quant_config.
@@ -674,7 +678,8 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input: bool,
) -> torch.Tensor:
- _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
+ _, M, N, K, top_k = self.fused_experts.moe_problem_size(
+ a1q, w1, w2, topk_ids)
(workspace13_shape, workspace2_shape, fused_out_shape,
workspace_dtype) = self.fused_experts.workspace_shapes(
@@ -737,7 +742,8 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input: bool,
) -> torch.Tensor:
- _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
+ _, M, N, K, top_k = self.fused_experts.moe_problem_size(
+ a1q, w1, w2, topk_ids)
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
num_chunks = cdiv(M, CHUNK_SIZE)
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index 54194b2e7d5b0..950bf33dbf01d 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config,
mxfp4_w4a16_moe_quant_config)
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import MarlinExperts
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
OAITritonExperts)
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
@@ -92,7 +93,7 @@ def get_mxfp4_backend():
"Please `pip install vllm[flashinfer]` for best results.")
# If FlashInfer is not available, try either Marlin or Triton
- if current_platform.get_device_capability(
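+    # VLLM_MXFP4_USE_MARLIN forces the Marlin backend; otherwise Marlin is
+    # the fallback when triton_kernels, SM90+, or torch >= 2.8 is unavailable.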
+ if envs.VLLM_MXFP4_USE_MARLIN or current_platform.get_device_capability(
)[0] < 9 or not has_triton_kernels() or not is_torch_equal_or_newer(
"2.8.0"):
logger.info_once("Using Marlin backend")
@@ -646,9 +647,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
- return None
-
- if self.mxfp4_backend == Mxfp4Backend.TRITON:
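+            # The Marlin modular-kernel path (MarlinExperts) reads scales and
+            # biases from the quant config, so return a real config here
+            # instead of None.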
+ return mxfp4_w4a16_moe_quant_config(
+ w1_bias=layer.w13_bias,
+ w2_bias=layer.w2_bias,
+ w1_scale=layer.w13_weight_scale,
+ w2_scale=layer.w2_weight_scale,
+ )
+ elif self.mxfp4_backend == Mxfp4Backend.TRITON:
w1_scale = self.w13_precision_config
w2_scale = self.w2_precision_config
return mxfp4_w4a16_moe_quant_config(
@@ -690,6 +695,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
}
return TrtLlmGenExperts(self.moe, self.moe_quant_config,
**kwargs)
+        elif self.mxfp4_backend == Mxfp4Backend.MARLIN:
+ return MarlinExperts(self.moe_quant_config)
else:
return OAITritonExperts(self.moe_quant_config)
@@ -782,6 +789,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
if enable_eplb:
raise NotImplementedError("EPLB is not supported for mxfp4")
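+        # Prefer the modular-kernel path (e.g. MarlinExperts selected in
+        # select_gemm_impl) over the backend-specific branches below.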
+ if self.fused_experts is not None:
+ return self._route_and_experts(
+ layer,
+ x,
+ router_logits,
+ top_k,
+ renormalize,
+ use_grouped_topk,
+ topk_group,
+ num_expert_group,
+ global_num_experts,
+ expert_map,
+ custom_routing_function,
+ scoring_func,
+ e_score_correction_bias,
+ apply_router_weight_on_input,
+ activation,
+ enable_eplb,
+ expert_load_view,
+ logical_to_physical_map,
+ logical_replica_count,
+ )
+
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
@@ -815,29 +845,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
activation=activation,
expert_map=expert_map)
- if self.fused_experts is not None:
- return self._route_and_experts(
- layer,
- x,
- router_logits,
- top_k,
- renormalize,
- use_grouped_topk,
- topk_group,
- num_expert_group,
- global_num_experts,
- expert_map,
- custom_routing_function,
- scoring_func,
- e_score_correction_bias,
- apply_router_weight_on_input,
- activation,
- enable_eplb,
- expert_load_view,
- logical_to_physical_map,
- logical_replica_count,
- )
-
assert _can_support_mxfp4(
use_grouped_topk, topk_group, num_expert_group, expert_map,
custom_routing_function, e_score_correction_bias,
diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py
index 317ad079b392d..6c7604cc9d048 100644
--- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py
@@ -187,6 +187,16 @@ def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \
supports_router_weight and supports_activation
+def marlin_moe_intermediate_size(w1_packed: torch.Tensor,
+                                 w2_packed: torch.Tensor) -> int:
+    """
+    Given Marlin-packed weight matrices w1_packed and w2_packed,
+    return the MoE intermediate size N.
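+    Marlin packs the reduction dimension of each GEMM in 16-wide tiles,
+    so w2_packed.size(1) == N // 16 (e.g. w2_packed.size(1) == 128
+    corresponds to N == 2048).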
+ """
+ marlin_tile_size = 16
+ return w2_packed.size(1) * marlin_tile_size
+
+
def marlin_make_workspace(output_size_per_partition: int,
device: torch.device) -> torch.Tensor:
max_workspace_size = (output_size_per_partition //