mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-02 13:24:31 +08:00
[Bugfix] Fix dp_chunking enablement logic in FusedMoE layer (#27220)
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
This commit is contained in:
parent
295c7f0267
commit
9ef3d5b875
@@ -1102,6 +1102,7 @@ class FusedMoE(CustomOp):
|
|||||||
self.params_dtype = params_dtype
|
self.params_dtype = params_dtype
|
||||||
|
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
|
self.vllm_config = vllm_config
|
||||||
|
|
||||||
# FIXME (varun): We should have a better way of inferring the activation
|
# FIXME (varun): We should have a better way of inferring the activation
|
||||||
# datatype. This works for now as the tensor datatype entering the MoE
|
# datatype. This works for now as the tensor datatype entering the MoE
|
||||||
@@ -1342,26 +1343,6 @@ class FusedMoE(CustomOp):
|
|||||||
self.batched_hidden_states: torch.Tensor | None = None
|
self.batched_hidden_states: torch.Tensor | None = None
|
||||||
self.batched_router_logits: torch.Tensor | None = None
|
self.batched_router_logits: torch.Tensor | None = None
|
||||||
|
|
||||||
if self.use_dp_chunking:
|
|
||||||
states_shape: tuple[int, ...]
|
|
||||||
logits_shape: tuple[int, ...]
|
|
||||||
|
|
||||||
# Note here we use `num_experts` which is logical expert count
|
|
||||||
if vllm_config.parallel_config.enable_dbo:
|
|
||||||
states_shape = (2, moe.max_num_tokens, self.hidden_size)
|
|
||||||
logits_shape = (2, moe.max_num_tokens, num_experts)
|
|
||||||
else:
|
|
||||||
states_shape = (moe.max_num_tokens, self.hidden_size)
|
|
||||||
logits_shape = (moe.max_num_tokens, num_experts)
|
|
||||||
|
|
||||||
self.batched_hidden_states = torch.zeros(
|
|
||||||
states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
|
|
||||||
)
|
|
||||||
|
|
||||||
self.batched_router_logits = torch.zeros(
|
|
||||||
logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def shared_experts(self) -> torch.nn.Module | None:
|
def shared_experts(self) -> torch.nn.Module | None:
|
||||||
return None
|
return None
|
||||||
@@ -1420,8 +1401,6 @@ class FusedMoE(CustomOp):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def use_dp_chunking(self) -> bool:
|
def use_dp_chunking(self) -> bool:
|
||||||
# Route to the chunked forward path using the FlashInfer Cutlass kernel
|
|
||||||
# only when data parallelism (DP) is enabled.
|
|
||||||
return (
|
return (
|
||||||
self.moe_parallel_config.use_pplx_kernels
|
self.moe_parallel_config.use_pplx_kernels
|
||||||
or self.moe_parallel_config.use_deepep_ll_kernels
|
or self.moe_parallel_config.use_deepep_ll_kernels
|
||||||
@@ -1988,12 +1967,40 @@ class FusedMoE(CustomOp):
|
|||||||
self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
|
self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
|
||||||
self.logical_replica_count = logical_replica_count[moe_layer_idx]
|
self.logical_replica_count = logical_replica_count[moe_layer_idx]
|
||||||
|
|
||||||
def ensure_moe_quant_config(self):
|
def ensure_moe_quant_config_init(self):
|
||||||
if self.quant_method.moe_quant_config is None:
|
if self.quant_method.moe_quant_config is None:
|
||||||
self.quant_method.moe_quant_config = (
|
self.quant_method.moe_quant_config = (
|
||||||
self.quant_method.get_fused_moe_quant_config(self)
|
self.quant_method.get_fused_moe_quant_config(self)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.moe_quant_config is None:
|
||||||
|
self.moe_quant_config = self.quant_method.moe_quant_config
|
||||||
|
|
||||||
|
def ensure_dp_chunking_init(self):
|
||||||
|
if not self.use_dp_chunking or self.batched_hidden_states is not None:
|
||||||
|
return
|
||||||
|
|
||||||
|
states_shape: tuple[int, ...]
|
||||||
|
logits_shape: tuple[int, ...]
|
||||||
|
|
||||||
|
moe = self.moe_config
|
||||||
|
|
||||||
|
# Note here we use `num_experts` which is logical expert count
|
||||||
|
if self.vllm_config.parallel_config.enable_dbo:
|
||||||
|
states_shape = (2, moe.max_num_tokens, self.hidden_size)
|
||||||
|
logits_shape = (2, moe.max_num_tokens, moe.num_experts)
|
||||||
|
else:
|
||||||
|
states_shape = (moe.max_num_tokens, self.hidden_size)
|
||||||
|
logits_shape = (moe.max_num_tokens, moe.num_experts)
|
||||||
|
|
||||||
|
self.batched_hidden_states = torch.zeros(
|
||||||
|
states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.batched_router_logits = torch.zeros(
|
||||||
|
logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def select_experts(
|
def select_experts(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -2224,8 +2231,6 @@ class FusedMoE(CustomOp):
|
|||||||
assert self.batched_hidden_states.size(-1) == full_hidden_states.size(-1)
|
assert self.batched_hidden_states.size(-1) == full_hidden_states.size(-1)
|
||||||
assert self.batched_router_logits.size(-1) == full_router_logits.size(-1)
|
assert self.batched_router_logits.size(-1) == full_router_logits.size(-1)
|
||||||
|
|
||||||
self.ensure_moe_quant_config()
|
|
||||||
|
|
||||||
full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
|
full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
|
||||||
if self.shared_experts is not None:
|
if self.shared_experts is not None:
|
||||||
full_shared_final_hidden_states = torch.empty_like(full_hidden_states)
|
full_shared_final_hidden_states = torch.empty_like(full_hidden_states)
|
||||||
@@ -2383,7 +2388,8 @@ class FusedMoE(CustomOp):
|
|||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
|
||||||
self.ensure_moe_quant_config()
|
self.ensure_moe_quant_config_init()
|
||||||
|
self.ensure_dp_chunking_init()
|
||||||
|
|
||||||
has_separate_shared_experts = (
|
has_separate_shared_experts = (
|
||||||
not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
|
not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user