[Model][gpt-oss] Support DP+EP for GPT-OSS with FlashInfer trtllm-gen MoE (#23819)

Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
Po-Han Huang (NVIDIA) authored on 2025-08-28 21:56:20 +08:00, committed by GitHub
parent 1f096f9b95
commit 95089607fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 14 additions and 15 deletions

View File

@@ -190,12 +190,6 @@ class FusedMoEParallelConfig:
         return (self.use_all2all_kernels
                 and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
 
-    @property
-    def use_flashinfer_cutlass_kernels(self):
-        return (envs.VLLM_USE_FLASHINFER_MOE_FP4
-                and has_flashinfer_cutlass_fused_moe()
-                and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")
-
     @staticmethod
     def make(tp_size_: int, dp_size_: int,
              vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
@@ -404,7 +398,14 @@ class FusedMoEConfig:
 
     @property
     def use_flashinfer_cutlass_kernels(self):
-        return self.moe_parallel_config.use_flashinfer_cutlass_kernels
+        """
+        Whether to use FlashInfer cutlass kernels for NVFP4 MoE.
+        """
+        return (self.quant_config is not None
+                and self.quant_config.quant_dtype == "nvfp4"
+                and envs.VLLM_USE_FLASHINFER_MOE_FP4
+                and has_flashinfer_cutlass_fused_moe()
+                and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")
 
     @staticmethod
     def make(
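Net effect of the two hunks above: the FlashInfer CUTLASS gating moves from FusedMoEParallelConfig to FusedMoEConfig, where the quantization config is visible, and it now requires NVFP4. A standalone restatement of the new check as a sketch (hypothetical helper, not vLLM code; env flags and capability checks are passed in as plain values):

def _uses_flashinfer_cutlass(quant_dtype, fp4_env_enabled, has_cutlass_moe, backend):
    # Mirrors the new FusedMoEConfig property: only NVFP4 quantization with the
    # FlashInfer "throughput" backend selects the CUTLASS fused-MoE path; other
    # quant dtypes (e.g. the MXFP4 path GPT-OSS uses) no longer trip this flag.
    return (quant_dtype == "nvfp4"
            and fp4_env_enabled
            and has_cutlass_moe
            and backend == "throughput")

# e.g. _uses_flashinfer_cutlass("mxfp4", True, True, "throughput") -> False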

View File

@@ -920,7 +920,7 @@ class FusedMoE(CustomOp):
         self.batched_router_logits: Optional[torch.Tensor] = None
         if (self.moe_parallel_config.use_pplx_kernels
                 or self.moe_parallel_config.use_deepep_ll_kernels
-                or self.moe_parallel_config.use_flashinfer_cutlass_kernels):
+                or self.moe_config.use_flashinfer_cutlass_kernels):
             self.batched_hidden_states = torch.zeros(
                 (moe.max_num_tokens, self.hidden_size),
                 dtype=moe.in_dtype,
@@ -974,7 +974,7 @@ class FusedMoE(CustomOp):
 
     @property
     def use_flashinfer_cutlass_kernels(self):
-        return self.moe_parallel_config.use_flashinfer_cutlass_kernels
+        return self.moe_config.use_flashinfer_cutlass_kernels
 
     def update_expert_map(self):
         # ep_size and ep_rank should already be updated
@@ -1665,7 +1665,7 @@ class FusedMoE(CustomOp):
         # only when data parallelism (DP) is enabled.
         use_flashinfer_cutlass_kernels = (
             self.dp_size > 1
-            and self.moe_parallel_config.use_flashinfer_cutlass_kernels)
+            and self.moe_config.use_flashinfer_cutlass_kernels)
         if (self.moe_parallel_config.use_pplx_kernels
                 or self.moe_parallel_config.use_deepep_ll_kernels
                 or use_flashinfer_cutlass_kernels):
@@ -1674,7 +1674,7 @@ class FusedMoE(CustomOp):
         do_naive_dispatch_combine: bool = (
             self.dp_size > 1
             and not self.moe_parallel_config.use_deepep_ht_kernels
-            and not self.moe_parallel_config.use_flashinfer_cutlass_kernels)
+            and not self.moe_config.use_flashinfer_cutlass_kernels)
         if do_naive_dispatch_combine:
             hidden_states, router_logits = get_ep_group().dispatch(
                 hidden_states, router_logits)
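The hunks above only redirect the flag lookups from moe_parallel_config to the new moe_config property, but the two forward-path conditions are easier to read side by side. A sketch that restates them with the backend flags as plain booleans (hypothetical helper, not the vLLM API):

def forward_flags(dp_size, use_pplx, use_deepep_ll, use_deepep_ht, use_fi_cutlass):
    # Returns (stage_into_batched_buffers, do_naive_dispatch_combine).
    # The FlashInfer CUTLASS kernels use the batched buffers only under DP.
    fi_cutlass_dp = dp_size > 1 and use_fi_cutlass
    stage_batched = use_pplx or use_deepep_ll or fi_cutlass_dp
    # The naive EP dispatch/combine is skipped for DeepEP high-throughput and
    # for the FlashInfer CUTLASS backend.
    naive_dispatch_combine = (dp_size > 1
                              and not use_deepep_ht
                              and not use_fi_cutlass)
    return stage_batched, naive_dispatch_combine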

View File

@@ -623,8 +623,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         if should_use_flashinfer_mxfp4():
             from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe
 
-            assert not self.moe.use_ep, (
-                "EP is not supported for flashinfer mxfp4 moe backend yet.")
             if _should_use_flashinfer_mxfp4_bf16():
                 assert x.dtype == torch.bfloat16
                 x_quant = x
@@ -650,12 +648,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 None,  # output1_scale_scalar
                 None,  # output1_scale_gate_scalar
                 None,  # output2_scale_scalar
-                self.num_experts,
+                global_num_experts,
                 top_k,
                 None,  # n_group
                 None,  # topk_group
                 self.intermediate_size,  # padded to multiple of 256
-                0,  # local_expert_offset
+                layer.ep_rank * layer.local_num_experts,  # local_expert_offset
                 self.num_experts,  # local num experts
                 None,
                 self._get_tile_tokens_dim(x, top_k),
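The two argument changes in this last hunk are what enable EP on the trtllm-gen MoE path: the kernel now receives the global expert count and this rank's expert offset instead of assuming it owns all experts starting at 0. A minimal sketch of that layout, assuming experts divide evenly across EP ranks (hypothetical helper, not vLLM code):

def local_expert_slice(global_num_experts, ep_size, ep_rank):
    # Each EP rank owns a contiguous block of experts; local_expert_offset in
    # the call above is rank * experts_per_rank.
    local_num_experts = global_num_experts // ep_size
    local_expert_offset = ep_rank * local_num_experts
    return local_expert_offset, local_num_experts

# e.g. 128 experts over 4 EP ranks: rank 1 owns experts [32, 64)
# local_expert_slice(128, 4, 1) == (32, 32)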