[Bugfix] Fix dp_chunking enablement logic in FusedMoE layer (#27220)
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
parent 295c7f0267
commit 9ef3d5b875
@@ -1102,6 +1102,7 @@ class FusedMoE(CustomOp):
        self.params_dtype = params_dtype

        vllm_config = get_current_vllm_config()
        self.vllm_config = vllm_config

        # FIXME (varun): We should have a better way of inferring the activation
        # datatype. This works for now as the tensor datatype entering the MoE
@@ -1342,26 +1343,6 @@ class FusedMoE(CustomOp):
        self.batched_hidden_states: torch.Tensor | None = None
        self.batched_router_logits: torch.Tensor | None = None

        if self.use_dp_chunking:
            states_shape: tuple[int, ...]
            logits_shape: tuple[int, ...]

            # Note here we use `num_experts` which is logical expert count
            if vllm_config.parallel_config.enable_dbo:
                states_shape = (2, moe.max_num_tokens, self.hidden_size)
                logits_shape = (2, moe.max_num_tokens, num_experts)
            else:
                states_shape = (moe.max_num_tokens, self.hidden_size)
                logits_shape = (moe.max_num_tokens, num_experts)

            self.batched_hidden_states = torch.zeros(
                states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
            )

            self.batched_router_logits = torch.zeros(
                logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
            )

    @property
    def shared_experts(self) -> torch.nn.Module | None:
        return None
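The block removed above is the eager allocation of the DP-chunking staging buffers in __init__; the rest of the diff reintroduces the same logic as a lazy helper. For reference, the buffer shapes follow the chunk size, gain a leading dimension of 2 under dual-batch overlap (DBO), and size the logits buffer by the logical expert count. A minimal standalone sketch with made-up values (max_num_tokens, hidden_size, and num_experts below are illustrative, not taken from any real config):

import torch

# Illustrative values only; the real ones come from the MoE and parallel configs.
max_num_tokens, hidden_size, num_experts = 1024, 4096, 64
enable_dbo = True  # dual-batch overlap doubles the leading dimension

if enable_dbo:
    states_shape = (2, max_num_tokens, hidden_size)
    logits_shape = (2, max_num_tokens, num_experts)
else:
    states_shape = (max_num_tokens, hidden_size)
    logits_shape = (max_num_tokens, num_experts)

batched_hidden_states = torch.zeros(states_shape, dtype=torch.bfloat16)
batched_router_logits = torch.zeros(logits_shape, dtype=torch.bfloat16)
print(batched_hidden_states.shape)  # torch.Size([2, 1024, 4096])
print(batched_router_logits.shape)  # torch.Size([2, 1024, 64])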
@@ -1420,8 +1401,6 @@ class FusedMoE(CustomOp):

    @property
    def use_dp_chunking(self) -> bool:
        # Route to the chunked forward path using the FlashInfer Cutlass kernel
        # only when data parallelism (DP) is enabled.
        return (
            self.moe_parallel_config.use_pplx_kernels
            or self.moe_parallel_config.use_deepep_ll_kernels
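This hunk trims the comment on use_dp_chunking, and the page cuts the property off before the end of its return expression, so the full enablement condition is not visible here. As a rough illustration only, a toy property that gates chunking on all-to-all kernel flags could look like the sketch below (ToyParallelConfig and ToyMoELayer are invented stand-ins, and the real vLLM condition contains terms not shown in this hunk):

from dataclasses import dataclass

@dataclass
class ToyParallelConfig:
    # Invented stand-in; the real FusedMoEParallelConfig carries more state.
    use_pplx_kernels: bool = False
    use_deepep_ll_kernels: bool = False

class ToyMoELayer:
    def __init__(self, parallel_config: ToyParallelConfig):
        self.moe_parallel_config = parallel_config

    @property
    def use_dp_chunking(self) -> bool:
        # Simplified gating: chunk the DP forward pass only when an
        # all-to-all kernel that needs it is selected.
        return (
            self.moe_parallel_config.use_pplx_kernels
            or self.moe_parallel_config.use_deepep_ll_kernels
        )

print(ToyMoELayer(ToyParallelConfig(use_pplx_kernels=True)).use_dp_chunking)  # True
print(ToyMoELayer(ToyParallelConfig()).use_dp_chunking)                       # False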
@@ -1988,12 +1967,40 @@ class FusedMoE(CustomOp):
        self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
        self.logical_replica_count = logical_replica_count[moe_layer_idx]

    def ensure_moe_quant_config(self):
    def ensure_moe_quant_config_init(self):
        if self.quant_method.moe_quant_config is None:
            self.quant_method.moe_quant_config = (
                self.quant_method.get_fused_moe_quant_config(self)
            )

        if self.moe_quant_config is None:
            self.moe_quant_config = self.quant_method.moe_quant_config

    def ensure_dp_chunking_init(self):
        if not self.use_dp_chunking or self.batched_hidden_states is not None:
            return

        states_shape: tuple[int, ...]
        logits_shape: tuple[int, ...]

        moe = self.moe_config

        # Note here we use `num_experts` which is logical expert count
        if self.vllm_config.parallel_config.enable_dbo:
            states_shape = (2, moe.max_num_tokens, self.hidden_size)
            logits_shape = (2, moe.max_num_tokens, moe.num_experts)
        else:
            states_shape = (moe.max_num_tokens, self.hidden_size)
            logits_shape = (moe.max_num_tokens, moe.num_experts)

        self.batched_hidden_states = torch.zeros(
            states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
        )

        self.batched_router_logits = torch.zeros(
            logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
        )

    @staticmethod
    def select_experts(
        hidden_states: torch.Tensor,
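In this hunk, ensure_moe_quant_config is renamed to ensure_moe_quant_config_init and the new ensure_dp_chunking_init helper takes over the buffer allocation that was removed from __init__: it returns immediately unless chunking is enabled and the buffers are still unallocated, so it can be called on every forward at negligible cost. A minimal standalone sketch of that lazy, idempotent initialization pattern (ToyLayer and its fields are invented for illustration, not vLLM code):

import torch

class ToyLayer:
    def __init__(self, use_dp_chunking: bool, max_num_tokens: int, hidden_size: int):
        self.use_dp_chunking = use_dp_chunking
        self.max_num_tokens = max_num_tokens
        self.hidden_size = hidden_size
        # Nothing is allocated at construction time.
        self.batched_hidden_states: torch.Tensor | None = None

    def ensure_dp_chunking_init(self) -> None:
        # No-op unless chunking is enabled and the buffer is still missing.
        if not self.use_dp_chunking or self.batched_hidden_states is not None:
            return
        self.batched_hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size), dtype=torch.bfloat16
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.ensure_dp_chunking_init()  # cheap to call on every forward
        return x

layer = ToyLayer(use_dp_chunking=True, max_num_tokens=8, hidden_size=16)
layer.forward(torch.randn(4, 16))
assert layer.batched_hidden_states is not None
assert layer.batched_hidden_states.shape == (8, 16)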
@@ -2224,8 +2231,6 @@ class FusedMoE(CustomOp):
        assert self.batched_hidden_states.size(-1) == full_hidden_states.size(-1)
        assert self.batched_router_logits.size(-1) == full_router_logits.size(-1)

        self.ensure_moe_quant_config()

        full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
        if self.shared_experts is not None:
            full_shared_final_hidden_states = torch.empty_like(full_hidden_states)
@@ -2383,7 +2388,8 @@ class FusedMoE(CustomOp):
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        assert self.quant_method is not None

        self.ensure_moe_quant_config()
        self.ensure_moe_quant_config_init()
        self.ensure_dp_chunking_init()

        has_separate_shared_experts = (
            not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
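The forward path drops the old ensure_moe_quant_config() call in favor of the renamed helper plus ensure_dp_chunking_init(), so both the quant config and the DP-chunking buffers are materialized on first use. The quant-config half is a compute-once-and-cache pattern; a standalone toy that mirrors it is sketched below (ToyQuantMethod and ToyFusedLayer are invented stand-ins, with a plain dict in place of the real config object):

class ToyQuantMethod:
    def __init__(self):
        self.moe_quant_config = None  # filled in lazily

    def get_fused_moe_quant_config(self, layer):
        # Stand-in for the real (potentially expensive) config construction.
        return {"dtype": "fp8", "per_tensor": False}

class ToyFusedLayer:
    def __init__(self, quant_method: ToyQuantMethod):
        self.quant_method = quant_method
        self.moe_quant_config = None

    def ensure_moe_quant_config_init(self) -> None:
        # Build the quant config once and cache it on both objects.
        if self.quant_method.moe_quant_config is None:
            self.quant_method.moe_quant_config = (
                self.quant_method.get_fused_moe_quant_config(self)
            )
        if self.moe_quant_config is None:
            self.moe_quant_config = self.quant_method.moe_quant_config

layer = ToyFusedLayer(ToyQuantMethod())
layer.ensure_moe_quant_config_init()
layer.ensure_moe_quant_config_init()  # idempotent: the second call changes nothing
assert layer.moe_quant_config is layer.quant_method.moe_quant_config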