[DP/EP][GPTOSS] Use triton matmul-ogs kernels for GPTOSS DP/EP (#24588)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
Varun Sundar Rabindranath 2025-09-23 00:01:09 -04:00 committed by GitHub
parent fafbe11af4
commit e8db44f883
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 275 additions and 76 deletions

View File

@ -288,7 +288,11 @@ class FusedMoEQuantConfig:
@property @property
def use_mxfp4_w4a4(self) -> bool: def use_mxfp4_w4a4(self) -> bool:
return self.quant_dtype == "mxfp4" return (self._a1.dtype == "mxfp4" and self._w1.dtype == "mxfp4")
@property
def use_mxfp4_w4a16(self) -> bool:
return (self._a1.dtype is None and self._w1.dtype == "mxfp4")
@property @property
def use_nvfp4_w4a4(self) -> bool: def use_nvfp4_w4a4(self) -> bool:
@ -453,6 +457,22 @@ def int8_w8a8_moe_quant_config(
) )
def mxfp4_w4a16_moe_quant_config(
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None) -> FusedMoEQuantConfig:
"""
Construct a quant config for unquantized activations and mxfp4 weights.
"""
return FusedMoEQuantConfig(
_a1=FusedMoEQuantDesc(),
_a2=FusedMoEQuantDesc(),
_w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
_w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
)
def mxfp4_w4a4_moe_quant_config( def mxfp4_w4a4_moe_quant_config(
w1_scale: Union[torch.Tensor, "PrecisionConfig"], w1_scale: Union[torch.Tensor, "PrecisionConfig"],
w2_scale: Union[torch.Tensor, "PrecisionConfig"], w2_scale: Union[torch.Tensor, "PrecisionConfig"],

View File

@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate) TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input) moe_kernel_quantize_input)
from vllm.utils import round_up
class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
@ -18,6 +19,23 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
Prepare/Finalize using DeepEP High-Throughput kernels. Prepare/Finalize using DeepEP High-Throughput kernels.
""" """
@staticmethod
def maybe_roundup_layer_hidden_size(hidden_size: int,
dtype: torch.dtype) -> int:
# Round up hidden size so it is compatible with DeepEP High Throughput
# kernels.
# DeepEP intranode kernels make copies in units of,
# 32(warp-size) int4 elements. Round up hidden size to respect this.
# For example, an input hidden size of 2880 with dtype torch.bfloat16
# will be rounded up to 3072.
hidden_size_bytes = hidden_size * dtype.itemsize
xfer_atom_size = 512 # 32 * 16 (size(int4))
if hidden_size_bytes % xfer_atom_size == 0:
return hidden_size
hidden_size_bytes = round_up(hidden_size_bytes, xfer_atom_size)
return hidden_size_bytes // dtype.itemsize
def __init__(self, buffer: deep_ep.Buffer, num_dispatchers: int, def __init__(self, buffer: deep_ep.Buffer, num_dispatchers: int,
dp_size: int, rank_expert_offset: int): dp_size: int, rank_expert_offset: int):
super().__init__() super().__init__()

View File

@ -9,7 +9,8 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig) FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate) TopKWeightAndReduceNoOP)
from vllm.triton_utils import tl, triton
from vllm.utils import has_triton_kernels from vllm.utils import has_triton_kernels
logger = init_logger(__name__) logger = init_logger(__name__)
@ -19,13 +20,55 @@ if has_triton_kernels():
import triton_kernels.swiglu import triton_kernels.swiglu
from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation, from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation,
matmul_ogs) matmul_ogs)
from triton_kernels.routing import routing from triton_kernels.routing import (RoutingData, routing,
routing_from_bitmatrix)
from triton_kernels.tensor import Bitmatrix
except (ModuleNotFoundError, AttributeError) as e: except (ModuleNotFoundError, AttributeError) as e:
logger.error( logger.error(
"Failed to import Triton kernels. Please make sure your triton " "Failed to import Triton kernels. Please make sure your triton "
"version is compatible. Error: %s", e) "version is compatible. Error: %s", e)
@triton.jit
def pack_bitmatrix(
bitmatrix,
topk_ids,
n_rows, # n_rows in bitmatrix / topk_ids
bm_cols: tl.constexpr, # n int32_t bitpacks in bitmatrix
n_expts_act, # num_topk
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
"""
Packs topk_ids into a bitmatrix.
code reference:
https://github.com/triton-lang/triton/blob/dd1bbc52b34d202dfe5ffea1e04fb16166c5c04e/python/triton_kernels/bench/distributed.py#L264
"""
pid_m = tl.program_id(0)
offsets_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offsets_k = tl.arange(0, BLOCK_SIZE_K)
offsets = offsets_m[:, None] * n_expts_act + offsets_k[None, :]
mask = (offsets_m < n_rows)[:, None] & (offsets_k < n_expts_act)[None, :]
indices = tl.load(topk_ids + offsets, mask=mask, other=-1)
div = indices // 32
rem = indices % 32
one = tl.cast(1, tl.uint32)
# Iterate through all the relevant bitmatrix columns.
for i in range(bm_cols):
# When BLOCK_SIZE_K=32, offs is just the column index.
offs = tl.arange(0, BLOCK_SIZE_K // 32) + i * (BLOCK_SIZE_K // 32)
# All topks that need to go into this column has the correct bit set.
# Other bits are 0. x is a 2D tensor.
x = tl.where(div[:, :, None] == offs[None, None, :],
(one << rem)[:, :, None], 0)
# Reduce x to get a single int32_t bitpack.
y = tl.reduce_or(x, axis=1)
bitmatrix_ptrs = bitmatrix + offsets_m[:,
None] * bm_cols + offs[None, :]
tl.store(bitmatrix_ptrs, y, mask=offsets_m[:, None] < n_rows)
def triton_kernel_moe_forward( def triton_kernel_moe_forward(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1, # Tensor or triton_kernels.Tensor w1, # Tensor or triton_kernels.Tensor
@ -124,34 +167,88 @@ def triton_kernel_fused_experts(
return intermediate_cache3 return intermediate_cache3
class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def make_routing_data(
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
num_local_experts: int,
) -> tuple["RoutingData", torch.Tensor, torch.Tensor]:
def __init__( topk_ids = topk_ids.to(torch.int16)
self, topk_weights = topk_weights.to(torch.bfloat16)
max_num_tokens: int,
num_dispatchers: int, n_rows, num_topk = topk_ids.size()
quant_config: FusedMoEQuantConfig,
): BLOCK_SIZE_M = 512
BLOCK_SIZE_K = 32
bm_cols = triton.cdiv(num_local_experts, BLOCK_SIZE_K) # n_bitpacks
bitmatrix = torch.zeros((n_rows, bm_cols),
dtype=torch.uint32,
device=topk_ids.device)
grid = (triton.cdiv(n_rows, BLOCK_SIZE_M), )
pack_bitmatrix[grid](
bitmatrix,
topk_ids,
n_rows,
bm_cols,
num_topk,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_K=BLOCK_SIZE_K,
)
bitmatrix_shape = [n_rows, bm_cols * 32]
bitmatrix_shape_max = [n_rows, None]
bitmatrix = Bitmatrix(bitmatrix,
shape=bitmatrix_shape,
shape_max=bitmatrix_shape_max,
scratchpad=None)
# matmul_ogs expects invalid topk_weights to be -1s
topk_weights = torch.where(topk_ids == -1, -1.0, topk_weights)
routing_data, gather_indx, scatter_indx = routing_from_bitmatrix(
bitmatrix, topk_weights, topk_ids, num_local_experts, num_topk)
return routing_data, gather_indx, scatter_indx
class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self, quant_config: FusedMoEQuantConfig):
super().__init__(quant_config)
def supports_expert_map(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Weight application and reduction happens in the fused_experts kernel.
return TopKWeightAndReduceNoOP()
def _make_routing_data(
self,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
num_local_experts: int,
) -> tuple["RoutingData", torch.Tensor, torch.Tensor]:
return make_routing_data(topk_ids, topk_weights, num_local_experts)
class OAITritonExperts(BaseOAITritonExperts):
def __init__(self, quant_config: FusedMoEQuantConfig):
# TODO (varun) : Enable activation quantization
assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
super().__init__(quant_config) super().__init__(quant_config)
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
@property @property
def activation_formats( def activation_formats(
self self
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (mk.FusedMoEActivationFormat.BatchedExperts, return (mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.BatchedExperts) mk.FusedMoEActivationFormat.Standard)
def supports_chunking(self) -> bool: def supports_chunking(self) -> bool:
return False return True
def supports_expert_map(self) -> bool:
return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
def workspace_shapes( def workspace_shapes(
self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int, self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
@ -159,13 +256,10 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta: Optional[mk.ExpertTokensMetadata] expert_tokens_meta: Optional[mk.ExpertTokensMetadata]
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
# workspace are allocated inside the kernel # workspace are allocated inside the kernel
assert a.dim() == 2 workspace1 = (M, K)
num_dp = self.num_dispatchers workspace2 = (0, 0)
num_experts = local_num_experts output = (M, K)
max_num_tokens = self.max_num_tokens return (workspace1, workspace2, output, a.dtype)
workspace2 = (0, 0, 0)
output = (num_experts, max_num_tokens * num_dp, N)
return (output, workspace2, output, a.dtype)
def apply( def apply(
self, self,
@ -185,17 +279,29 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
): ):
return triton_kernel_fused_experts( if expert_map is not None:
output, topk_ids = expert_map[topk_ids]
local_num_experts = w1.size(0)
if global_num_experts == -1:
global_num_experts = local_num_experts
routing_data, gather_indx, scatter_indx = self._make_routing_data(
topk_ids, topk_weights, local_num_experts)
experts_output = triton_kernel_fused_experts(
None,
hidden_states, hidden_states,
w1, w1,
w2, w2,
routing_data=None, routing_data,
gather_indx=None, gather_indx,
scatter_indx=None, scatter_indx,
activation=activation, activation=activation,
quant_config=self.quant_config, quant_config=self.quant_config,
apply_router_weight_on_input=False, apply_router_weight_on_input=False,
global_num_experts=global_num_experts, global_num_experts=local_num_experts,
expert_map=expert_map, expert_map=None, # applied already
a1q_scale=a1q_scale) a1q_scale=a1q_scale)
output.copy_(experts_output, non_blocking=True)

View File

@ -800,6 +800,49 @@ def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
for local_index, global_index in zip(local_indices, global_indices)) for local_index, global_index in zip(local_indices, global_indices))
def maybe_roundup_hidden_size(
hidden_size: int, act_dtype: torch.dtype,
quant_config: Optional[QuantizationConfig],
moe_parallel_config: FusedMoEParallelConfig) -> int:
"""
Given layer hidden size and MoE configurations, round up hidden_size
if necessary.
Args:
hidden_size(int): Layer hidden-size
act_dtype: Data type of the layer activations.
quant_config(FusedMoEQuantConfig): Fused MoE quantization configuration.
moe_parallel_config(FusedMoEParallelConfig): Fused MoE parallelization
strategy configuration.
Return:
Rounded up hidden_size if rounding up is required based on the configs.
Original hidden size otherwise.
"""
if (moe_parallel_config.use_deepep_ht_kernels):
hidden_size = (
DeepEPHTPrepareAndFinalize.maybe_roundup_layer_hidden_size(
hidden_size, act_dtype))
# we are padding globally so EP buffer allocation works
if quant_config and quant_config.get_name() == "mxfp4":
from vllm.model_executor.layers.quantization.mxfp4 import (
Mxfp4Backend, get_mxfp4_backend)
current_mxfp4_backend = get_mxfp4_backend()
if (current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
or current_mxfp4_backend
== Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS):
hidden_size = round_up(hidden_size, 128)
elif (current_platform.is_rocm() or current_mxfp4_backend
== Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
hidden_size = round_up(hidden_size, 256)
return hidden_size
@CustomOp.register("fused_moe") @CustomOp.register("fused_moe")
class FusedMoE(CustomOp): class FusedMoE(CustomOp):
"""FusedMoE layer for MoE models. """FusedMoE layer for MoE models.
@ -856,6 +899,18 @@ class FusedMoE(CustomOp):
params_dtype = torch.get_default_dtype() params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype self.params_dtype = params_dtype
vllm_config = get_current_vllm_config()
# FIXME (varun): We should have a better way of inferring the activation
# datatype. This works for now as the tensor datatype entering the MoE
# operation is typically unquantized (i.e. float16/bfloat16).
if vllm_config.model_config is not None:
moe_in_dtype = vllm_config.model_config.dtype
else:
# TODO (bnell): This is a hack to get test_mixtral_moe to work
# since model_config is not set in the pytest test.
moe_in_dtype = params_dtype
tp_size_ = (tp_size if tp_size is not None else tp_size_ = (tp_size if tp_size is not None else
get_tensor_model_parallel_world_size()) get_tensor_model_parallel_world_size())
dp_size_ = (dp_size dp_size_ = (dp_size
@ -865,7 +920,6 @@ class FusedMoE(CustomOp):
if self.is_sequence_parallel: if self.is_sequence_parallel:
self.sp_size = tp_size_ self.sp_size = tp_size_
vllm_config = get_current_vllm_config()
self.moe_parallel_config: FusedMoEParallelConfig = ( self.moe_parallel_config: FusedMoEParallelConfig = (
FusedMoEParallelConfig.make( FusedMoEParallelConfig.make(
tp_size_=tp_size_, tp_size_=tp_size_,
@ -874,19 +928,10 @@ class FusedMoE(CustomOp):
self.global_num_experts = num_experts + num_redundant_experts self.global_num_experts = num_experts + num_redundant_experts
# we are padding globally so EP buffer allocation works # Round up hidden size if needed.
if quant_config and quant_config.get_name() == "mxfp4": hidden_size = maybe_roundup_hidden_size(hidden_size, moe_in_dtype,
from vllm.model_executor.layers.quantization.mxfp4 import ( quant_config,
Mxfp4Backend, get_mxfp4_backend) self.moe_parallel_config)
current_mxfp4_backend = get_mxfp4_backend()
if (current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
or current_mxfp4_backend
== Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS):
hidden_size = round_up(hidden_size, 128)
elif (current_platform.is_rocm() or current_mxfp4_backend
== Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM or
current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
hidden_size = round_up(hidden_size, 256)
# For smuggling this layer into the fused moe custom op # For smuggling this layer into the fused moe custom op
compilation_config = vllm_config.compilation_config compilation_config = vllm_config.compilation_config
@ -967,20 +1012,13 @@ class FusedMoE(CustomOp):
raise ValueError("Only softmax scoring function is supported for " raise ValueError("Only softmax scoring function is supported for "
"non-grouped topk.") "non-grouped topk.")
if vllm_config.model_config is not None:
model_dtype = vllm_config.model_config.dtype
else:
# TODO (bnell): This is a hack to get test_mixtral_moe to work
# since model_config is not set in the pytest test.
model_dtype = params_dtype
moe = FusedMoEConfig( moe = FusedMoEConfig(
num_experts=self.global_num_experts, num_experts=self.global_num_experts,
experts_per_token=top_k, experts_per_token=top_k,
hidden_dim=hidden_size, hidden_dim=hidden_size,
num_local_experts=self.local_num_experts, num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config, moe_parallel_config=self.moe_parallel_config,
in_dtype=model_dtype, in_dtype=moe_in_dtype,
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE, max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
has_bias=has_bias, has_bias=has_bias,
) )

View File

@ -76,7 +76,7 @@ def _moe_problem_size(
""" """
assert w1.dim() == 3 and w2.dim() == 3 assert w1.dim() == 3 and w2.dim() == 3
E, N, _ = w1.size() E, N, _ = w1.size()
K = w2.size(1) K = a1.size(-1)
if a1.dim() == 2: if a1.dim() == 2:
# Make sure we are using the correct a1 (pre-permute). # Make sure we are using the correct a1 (pre-permute).

View File

@ -13,7 +13,10 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase) FusedMoEMethodBase)
from vllm.model_executor.layers.fused_moe import modular_kernel as mk from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config) FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config,
mxfp4_w4a16_moe_quant_config)
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
OAITritonExperts)
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
@ -578,9 +581,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer.w13_bias = Parameter(w13_bias, requires_grad=False) layer.w13_bias = Parameter(w13_bias, requires_grad=False)
layer.w2_bias = Parameter(w2_bias, requires_grad=False) layer.w2_bias = Parameter(w2_bias, requires_grad=False)
# FIXME warp need to be adjusted based on batch size # Ideally we'd use FusedMoEModularKernel.prepare_finalize object
# only apply to batched mode # (stored in self.fused_experts) to determine if the MoE has a
if self.moe.use_ep: # batched activation format. As self.fused_experts is not
# initialized at this point, we resort to checking the MoE config
# directly.
is_batched_moe = (self.moe.use_pplx_kernels
or self.moe.use_deepep_ll_kernels)
if is_batched_moe:
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8 num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
else: else:
num_warps = 8 num_warps = 8
@ -640,16 +648,21 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
if self.mxfp4_backend == Mxfp4Backend.TRITON: if self.mxfp4_backend == Mxfp4Backend.TRITON:
w1_scale = self.w13_precision_config w1_scale = self.w13_precision_config
w2_scale = self.w2_precision_config w2_scale = self.w2_precision_config
return mxfp4_w4a16_moe_quant_config(
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
w1_scale=w1_scale,
w2_scale=w2_scale,
)
else: else:
w1_scale = layer.w13_weight_scale w1_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale w2_scale = layer.w2_weight_scale
return mxfp4_w4a4_moe_quant_config(
return mxfp4_w4a4_moe_quant_config( w1_bias=layer.w13_bias,
w1_bias=layer.w13_bias, w2_bias=layer.w2_bias,
w2_bias=layer.w2_bias, w1_scale=w1_scale,
w1_scale=w1_scale, w2_scale=w2_scale,
w2_scale=w2_scale, )
)
def select_gemm_impl( def select_gemm_impl(
self, self,
@ -661,6 +674,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
raise NotImplementedError( raise NotImplementedError(
"Mxfp4 does not support batched experts format for EP") "Mxfp4 does not support batched experts format for EP")
else: else:
assert self.moe_quant_config is not None
if (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM if (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16): or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
# B200 code-path # B200 code-path
@ -671,13 +685,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
# TODO(bnell): part of quant_config # TODO(bnell): part of quant_config
"max_capture_size": self.max_capture_size, "max_capture_size": self.max_capture_size,
} }
assert self.moe_quant_config is not None
return TrtLlmGenExperts(self.moe, self.moe_quant_config, return TrtLlmGenExperts(self.moe, self.moe_quant_config,
**kwargs) **kwargs)
else: else:
# Use matmul_ogs from triton_kernels here! return OAITritonExperts(self.moe_quant_config)
raise NotImplementedError(
"Mxfp4 does not support non-batched experts format for EP")
def _route_and_experts( def _route_and_experts(
self, self,
@ -722,10 +733,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
logical_to_physical_map=logical_to_physical_map, logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count) logical_replica_count=logical_replica_count)
w13_weight = (self.w13_weight_triton_tensor
if layer.w13_weight is None else layer.w13_weight)
w2_weight = (self.w2_weight_triton_tensor
if layer.w2_weight is None else layer.w2_weight)
assert all([w is not None for w in [w13_weight, w2_weight]])
return self.fused_experts( return self.fused_experts(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=w13_weight,
w2=layer.w2_weight, w2=w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,