[Bugfix] Fix dp_chunking enablement logic in FusedMoE layer (#27220)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
This commit is contained in:
Alexander Matveev 2025-10-23 12:03:14 -04:00 committed by GitHub
parent 295c7f0267
commit 9ef3d5b875
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1102,6 +1102,7 @@ class FusedMoE(CustomOp):
self.params_dtype = params_dtype
vllm_config = get_current_vllm_config()
self.vllm_config = vllm_config
# FIXME (varun): We should have a better way of inferring the activation
# datatype. This works for now as the tensor datatype entering the MoE
@ -1342,26 +1343,6 @@ class FusedMoE(CustomOp):
self.batched_hidden_states: torch.Tensor | None = None
self.batched_router_logits: torch.Tensor | None = None
if self.use_dp_chunking:
states_shape: tuple[int, ...]
logits_shape: tuple[int, ...]
# Note here we use `num_experts` which is logical expert count
if vllm_config.parallel_config.enable_dbo:
states_shape = (2, moe.max_num_tokens, self.hidden_size)
logits_shape = (2, moe.max_num_tokens, num_experts)
else:
states_shape = (moe.max_num_tokens, self.hidden_size)
logits_shape = (moe.max_num_tokens, num_experts)
self.batched_hidden_states = torch.zeros(
states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
)
self.batched_router_logits = torch.zeros(
logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
)
@property
def shared_experts(self) -> torch.nn.Module | None:
return None
@ -1420,8 +1401,6 @@ class FusedMoE(CustomOp):
@property
def use_dp_chunking(self) -> bool:
# Route to the chunked forward path using the FlashInfer Cutlass kernel
# only when data parallelism (DP) is enabled.
return (
self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels
@ -1988,12 +1967,40 @@ class FusedMoE(CustomOp):
self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
self.logical_replica_count = logical_replica_count[moe_layer_idx]
def ensure_moe_quant_config(self):
def ensure_moe_quant_config_init(self):
if self.quant_method.moe_quant_config is None:
self.quant_method.moe_quant_config = (
self.quant_method.get_fused_moe_quant_config(self)
)
if self.moe_quant_config is None:
self.moe_quant_config = self.quant_method.moe_quant_config
def ensure_dp_chunking_init(self):
if not self.use_dp_chunking or self.batched_hidden_states is not None:
return
states_shape: tuple[int, ...]
logits_shape: tuple[int, ...]
moe = self.moe_config
# Note here we use `num_experts` which is logical expert count
if self.vllm_config.parallel_config.enable_dbo:
states_shape = (2, moe.max_num_tokens, self.hidden_size)
logits_shape = (2, moe.max_num_tokens, moe.num_experts)
else:
states_shape = (moe.max_num_tokens, self.hidden_size)
logits_shape = (moe.max_num_tokens, moe.num_experts)
self.batched_hidden_states = torch.zeros(
states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
)
self.batched_router_logits = torch.zeros(
logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
)
@staticmethod
def select_experts(
hidden_states: torch.Tensor,
@ -2224,8 +2231,6 @@ class FusedMoE(CustomOp):
assert self.batched_hidden_states.size(-1) == full_hidden_states.size(-1)
assert self.batched_router_logits.size(-1) == full_router_logits.size(-1)
self.ensure_moe_quant_config()
full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
if self.shared_experts is not None:
full_shared_final_hidden_states = torch.empty_like(full_hidden_states)
@ -2383,7 +2388,8 @@ class FusedMoE(CustomOp):
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.quant_method is not None
self.ensure_moe_quant_config()
self.ensure_moe_quant_config_init()
self.ensure_dp_chunking_init()
has_separate_shared_experts = (
not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)