[Model][gpt-oss] Support DP+EP for GPT-OSS with FlashInfer trtllm-gen MoE (#23819)
Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
commit 95089607fa
parent 1f096f9b95
@@ -190,12 +190,6 @@ class FusedMoEParallelConfig:
         return (self.use_all2all_kernels
                 and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
 
-    @property
-    def use_flashinfer_cutlass_kernels(self):
-        return (envs.VLLM_USE_FLASHINFER_MOE_FP4
-                and has_flashinfer_cutlass_fused_moe()
-                and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")
-
     @staticmethod
     def make(tp_size_: int, dp_size_: int,
              vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
@@ -404,7 +398,14 @@ class FusedMoEConfig:
 
     @property
     def use_flashinfer_cutlass_kernels(self):
-        return self.moe_parallel_config.use_flashinfer_cutlass_kernels
+        """
+        Whether to use FlashInfer cutlass kernels for NVFP4 MoE.
+        """
+        return (self.quant_config is not None
+                and self.quant_config.quant_dtype == "nvfp4"
+                and envs.VLLM_USE_FLASHINFER_MOE_FP4
+                and has_flashinfer_cutlass_fused_moe()
+                and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")
 
     @staticmethod
     def make(
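The new property above is the reason the check moved: FusedMoEParallelConfig has no quantization information, so only FusedMoEConfig can restrict the FlashInfer cutlass path to NVFP4-quantized layers. A minimal self-contained sketch of the same gate, with illustrative dataclasses and plain booleans standing in for the VLLM_USE_FLASHINFER_MOE_FP4 and VLLM_FLASHINFER_MOE_BACKEND environment variables (not the real vLLM classes):

from dataclasses import dataclass
from typing import Optional


@dataclass
class QuantConfigSketch:
    quant_dtype: str  # e.g. "nvfp4"


@dataclass
class MoEConfigSketch:
    quant_config: Optional[QuantConfigSketch]
    flashinfer_fp4_enabled: bool  # stands in for VLLM_USE_FLASHINFER_MOE_FP4
    flashinfer_backend: str       # stands in for VLLM_FLASHINFER_MOE_BACKEND

    @property
    def use_flashinfer_cutlass_kernels(self) -> bool:
        # The quant_dtype check is what the parallel config could not express.
        return (self.quant_config is not None
                and self.quant_config.quant_dtype == "nvfp4"
                and self.flashinfer_fp4_enabled
                and self.flashinfer_backend == "throughput")


# An MXFP4 GPT-OSS layer therefore no longer takes the cutlass path:
assert not MoEConfigSketch(QuantConfigSketch("mxfp4"), True,
                           "throughput").use_flashinfer_cutlass_kernels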
@@ -920,7 +920,7 @@ class FusedMoE(CustomOp):
         self.batched_router_logits: Optional[torch.Tensor] = None
         if (self.moe_parallel_config.use_pplx_kernels
                 or self.moe_parallel_config.use_deepep_ll_kernels
-                or self.moe_parallel_config.use_flashinfer_cutlass_kernels):
+                or self.moe_config.use_flashinfer_cutlass_kernels):
             self.batched_hidden_states = torch.zeros(
                 (moe.max_num_tokens, self.hidden_size),
                 dtype=moe.in_dtype,
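For reference, this branch only decides whether the persistent per-rank staging buffers are allocated; their shape comes from the surrounding context lines. A minimal sketch with assumed values (the real ones are moe.max_num_tokens, self.hidden_size, and moe.in_dtype):

import torch

max_num_tokens, hidden_size = 256, 2880  # illustrative values only
batched_hidden_states = torch.zeros((max_num_tokens, hidden_size),
                                    dtype=torch.bfloat16)  # stand-in for moe.in_dtype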
@@ -974,7 +974,7 @@ class FusedMoE(CustomOp):
 
     @property
    def use_flashinfer_cutlass_kernels(self):
-        return self.moe_parallel_config.use_flashinfer_cutlass_kernels
+        return self.moe_config.use_flashinfer_cutlass_kernels
 
     def update_expert_map(self):
         # ep_size and ep_rank should already be updated
@@ -1665,7 +1665,7 @@ class FusedMoE(CustomOp):
         # only when data parallelism (DP) is enabled.
         use_flashinfer_cutlass_kernels = (
             self.dp_size > 1
-            and self.moe_parallel_config.use_flashinfer_cutlass_kernels)
+            and self.moe_config.use_flashinfer_cutlass_kernels)
         if (self.moe_parallel_config.use_pplx_kernels
                 or self.moe_parallel_config.use_deepep_ll_kernels
                 or use_flashinfer_cutlass_kernels):
@@ -1674,7 +1674,7 @@ class FusedMoE(CustomOp):
         do_naive_dispatch_combine: bool = (
             self.dp_size > 1
             and not self.moe_parallel_config.use_deepep_ht_kernels
-            and not self.moe_parallel_config.use_flashinfer_cutlass_kernels)
+            and not self.moe_config.use_flashinfer_cutlass_kernels)
         if do_naive_dispatch_combine:
             hidden_states, router_logits = get_ep_group().dispatch(
                 hidden_states, router_logits)
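The predicate edited here can be read in isolation: the naive all-gather dispatch/combine is only needed when DP is on and no backend (DeepEP high-throughput or the FlashInfer cutlass path) already handles cross-rank token movement. A minimal sketch with plain booleans standing in for the config properties:

def needs_naive_dispatch_combine(dp_size: int,
                                 use_deepep_ht_kernels: bool,
                                 use_flashinfer_cutlass_kernels: bool) -> bool:
    # Mirrors the do_naive_dispatch_combine expression above.
    return (dp_size > 1
            and not use_deepep_ht_kernels
            and not use_flashinfer_cutlass_kernels)


# e.g. DP=2 with the FlashInfer cutlass path active -> no naive dispatch
assert needs_naive_dispatch_combine(2, False, True) is False
assert needs_naive_dispatch_combine(2, False, False) is True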
@@ -623,8 +623,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
 
         if should_use_flashinfer_mxfp4():
             from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe
-            assert not self.moe.use_ep, (
-                "EP is not supported for flashinfer mxfp4 moe backend yet.")
             if _should_use_flashinfer_mxfp4_bf16():
                 assert x.dtype == torch.bfloat16
                 x_quant = x
@@ -650,12 +648,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 None,  # output1_scale_scalar
                 None,  # output1_scale_gate_scalar
                 None,  # output2_scale_scalar
-                self.num_experts,
+                global_num_experts,
                 top_k,
                 None,  # n_group
                 None,  # topk_group
                 self.intermediate_size,  # padded to multiple of 256
-                0,  # local_expert_offset
+                layer.ep_rank * layer.local_num_experts,  # local_expert_offset
                 self.num_experts,  # local num experts
                 None,
                 self._get_tile_tokens_dim(x, top_k),
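The two changed arguments are what let the trtllm-gen kernel run under EP: routing is computed over global_num_experts, while each EP rank owns a contiguous slice of experts identified by its offset. A minimal sketch of that mapping, assuming experts divide evenly across EP ranks (which is what layer.ep_rank * layer.local_num_experts implies); the function name is illustrative, not vLLM API:

def local_expert_range(ep_rank: int, ep_size: int,
                       global_num_experts: int) -> tuple[int, int]:
    """Return (local_expert_offset, local_num_experts) for one EP rank."""
    assert global_num_experts % ep_size == 0
    local_num_experts = global_num_experts // ep_size
    local_expert_offset = ep_rank * local_num_experts
    return local_expert_offset, local_num_experts


# e.g. 128 experts over 4 EP ranks -> rank 2 owns experts [64, 96)
assert local_expert_range(2, 4, 128) == (64, 32)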