mix_placement

Signed-off-by: Che Ruan <cr623@ic.ac.uk>

format fix

Signed-off-by: Che Ruan <cr623@ic.ac.uk>

format fix

Signed-off-by: Che Ruan <cr623@ic.ac.uk>

Update deepseek_v2.py

Signed-off-by: Mercykid-bash <ruanche0218@gmail.com>

format fix

Signed-off-by: Che Ruan <cr623@ic.ac.uk>

Update deepseek_v2.py

Signed-off-by: Mercykid-bash <ruanche0218@gmail.com>

format fix

Signed-off-by: Che Ruan <cr623@ic.ac.uk>
This commit is contained in:
Che Ruan 2025-12-24 10:53:19 +08:00
parent ddfac7034e
commit 25c203ff87
3 changed files with 101 additions and 33 deletions

View File

@ -34,7 +34,7 @@ from dataclasses import dataclass
import torch import torch
from torch.distributed import ProcessGroup, all_reduce from torch.distributed import ProcessGroup, all_reduce
from vllm.config import ModelConfig, ParallelConfig from vllm.config import ModelConfig, ParallelConfig, get_current_vllm_config
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
get_ep_group, get_ep_group,
get_node_count, get_node_count,
@ -269,6 +269,8 @@ class EplbState:
def build_initial_global_physical_to_logical_map( def build_initial_global_physical_to_logical_map(
num_routed_experts: int, num_routed_experts: int,
num_redundant_experts: int, num_redundant_experts: int,
num_shared_experts: int = 0,
mix_placement: bool = False,
) -> Sequence[int]: ) -> Sequence[int]:
""" """
Build an initial expert arrangement using the following structure: Build an initial expert arrangement using the following structure:
@ -279,10 +281,25 @@ class EplbState:
where each integer is the index of the logical expert where each integer is the index of the logical expert
that the corresponding physical expert maps to. that the corresponding physical expert maps to.
""" """
global_physical_to_logical_map = list(range(num_routed_experts)) ep_size = get_ep_group().world_size
global_physical_to_logical_map += [ num_physical_experts = num_routed_experts + num_redundant_experts
i % num_routed_experts for i in range(num_redundant_experts) if mix_placement:
] num_base_experts = num_physical_experts // ep_size
global_physical_to_logical_map = list()
for ep_rank in range(ep_size):
start_idx = ep_rank * num_base_experts
end_idx = (ep_rank + 1) * num_base_experts
global_physical_to_logical_map += [
i % num_routed_experts for i in range(start_idx, end_idx)
]
global_physical_to_logical_map += [
num_routed_experts + i for i in range(num_shared_experts)
]
else:
global_physical_to_logical_map = list(range(num_routed_experts))
global_physical_to_logical_map += [
i % num_routed_experts for i in range(num_redundant_experts)
]
return global_physical_to_logical_map return global_physical_to_logical_map
def validate_ep_configuration(self, new_model: MixtureOfExperts): def validate_ep_configuration(self, new_model: MixtureOfExperts):
@ -334,11 +351,18 @@ class EplbState:
""" """
self.validate_ep_configuration(model) self.validate_ep_configuration(model)
self.is_async = self.parallel_config.eplb_config.use_async self.is_async = self.parallel_config.eplb_config.use_async
additional_config = get_current_vllm_config().additional_config
if isinstance(additional_config, dict):
mix_placement = additional_config.get("mix_placement", False)
else:
mix_placement = getattr(additional_config, "mix_placement", False)
physical_to_logical_map_list = ( physical_to_logical_map_list = (
EplbState.build_initial_global_physical_to_logical_map( EplbState.build_initial_global_physical_to_logical_map(
model.num_routed_experts, model.num_routed_experts,
model.num_redundant_experts, model.num_redundant_experts,
model.num_shared_experts,
mix_placement,
) )
) )
physical_to_logical_map = torch.tensor( physical_to_logical_map = torch.tensor(

View File

@ -102,6 +102,7 @@ def determine_expert_map(
global_num_experts: int, global_num_experts: int,
expert_placement_strategy: ExpertPlacementStrategy = "linear", expert_placement_strategy: ExpertPlacementStrategy = "linear",
num_fused_shared_experts: int = 0, num_fused_shared_experts: int = 0,
mix_placement: bool = False,
return_expert_mask: bool = False, return_expert_mask: bool = False,
) -> tuple[int, torch.Tensor | None, torch.Tensor | None]: ) -> tuple[int, torch.Tensor | None, torch.Tensor | None]:
""" """
@ -138,6 +139,8 @@ def determine_expert_map(
# Distribute experts as evenly as possible to each rank. # Distribute experts as evenly as possible to each rank.
base_experts = global_num_experts // ep_size base_experts = global_num_experts // ep_size
if mix_placement:
base_experts += num_fused_shared_experts
remainder = global_num_experts % ep_size remainder = global_num_experts % ep_size
local_num_experts = base_experts + 1 if ep_rank < remainder else base_experts local_num_experts = base_experts + 1 if ep_rank < remainder else base_experts
@ -433,7 +436,11 @@ class FusedMoE(CustomOp):
self.expert_placement_strategy: ExpertPlacementStrategy = ( self.expert_placement_strategy: ExpertPlacementStrategy = (
vllm_config.parallel_config.expert_placement_strategy vllm_config.parallel_config.expert_placement_strategy
) )
additional_config = self.vllm_config.additional_config
if isinstance(additional_config, dict):
self.mix_placement = additional_config.get("mix_placement", False)
else:
self.mix_placement = getattr(additional_config, "mix_placement", False)
# ROCm aiter shared experts fusion # ROCm aiter shared experts fusion
self.rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled() self.rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
self.aiter_fmoe_shared_expert_enabled = ( self.aiter_fmoe_shared_expert_enabled = (
@ -442,7 +449,8 @@ class FusedMoE(CustomOp):
self.num_fused_shared_experts = ( self.num_fused_shared_experts = (
n_shared_experts n_shared_experts
if n_shared_experts is not None and self.aiter_fmoe_shared_expert_enabled if n_shared_experts is not None
and (self.aiter_fmoe_shared_expert_enabled or self.mix_placement)
else 0 else 0
) )
if ( if (
@ -507,10 +515,10 @@ class FusedMoE(CustomOp):
) )
self.top_k = top_k self.top_k = top_k
if self.aiter_fmoe_shared_expert_enabled:
self._init_aiter_shared_experts_topK_buffer( self._init_aiter_shared_experts_topK_buffer(
vllm_config=vllm_config, dp_size=dp_size_ vllm_config=vllm_config, dp_size=dp_size_
) )
if self.use_ep and self.rocm_aiter_fmoe_enabled: if self.use_ep and self.rocm_aiter_fmoe_enabled:
assert self.expert_mask is None or torch.all( assert self.expert_mask is None or torch.all(
(expert_mask == 0) | (expert_mask == 1) (expert_mask == 0) | (expert_mask == 1)
@ -1637,6 +1645,26 @@ class FusedMoE(CustomOp):
renormalize=self.renormalize, renormalize=self.renormalize,
) )
if self.mix_placement:
if self.routed_scaling_factor != 1.0:
topk_weights *= self.routed_scaling_factor
shared_expert_routing_factor = 1.0
batch_size = topk_ids.shape[0]
shared_expert_ids = torch.arrange(
self.logical_num_experts,
self.logical_num_experts + self.num_fused_shared_experts,
dtype=topk_ids.dtype,
device=topk_ids.device,
).repeat(batch_size, 1)
shared_expert_weights = torch.full(
(batch_size, self.num_fused_shared_experts),
shared_expert_routing_factor,
dtype=topk_ids.dtype,
device=topk_ids.device,
)
topk_ids = torch.cat([topk_ids, shared_expert_ids], dim=1)
topk_weights = torch.cat([topk_weights, shared_expert_weights], dim=1)
if self.enable_eplb: if self.enable_eplb:
topk_ids = eplb_map_to_physical_and_record( topk_ids = eplb_map_to_physical_and_record(
topk_ids=topk_ids, topk_ids=topk_ids,

View File

@ -276,21 +276,35 @@ class DeepseekV2MoE(nn.Module):
eplb_config = parallel_config.eplb_config eplb_config = parallel_config.eplb_config
self.enable_eplb = parallel_config.enable_eplb self.enable_eplb = parallel_config.enable_eplb
self.n_redundant_experts = eplb_config.num_redundant_experts additional_config = get_current_vllm_config().additional_config
self.n_logical_experts = self.n_routed_experts if isinstance(additional_config, dict):
self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts self.mix_placement = additional_config.get("mix_placement", False)
self.n_local_physical_experts = self.n_physical_experts // self.ep_size else:
self.mix_placement = getattr(additional_config, "mix_placement", False)
self.n_redundant_experts = eplb_config.num_redundant_experts
num_physical_routed_experts = self.n_routed_experts + self.n_redundant_experts
if self.mix_placement:
self.n_logical_experts = self.n_routed_experts + self.n_shared_experts
self.n_physical_experts = (
num_physical_routed_experts + self.ep_size * self.n_shared_experts
)
self.n_local_physical_experts = (
num_physical_routed_experts // self.ep_size + self.n_shared_experts
)
else:
self.n_logical_experts = self.n_routed_experts
self.n_physical_experts = num_physical_routed_experts
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
self.physical_expert_start = self.ep_rank * self.n_local_physical_experts self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
self.physical_expert_end = ( self.physical_expert_end = (
self.physical_expert_start + self.n_local_physical_experts self.physical_expert_start + self.n_local_physical_experts
) )
self.is_rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() self.is_rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
self.is_fusion_moe_shared_experts_enabled = ( if rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() self.mix_placement = True
) if config.n_shared_experts is None or self.mix_placement:
if config.n_shared_experts is None or self.is_fusion_moe_shared_experts_enabled:
self.shared_experts = None self.shared_experts = None
else: else:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts intermediate_size = config.moe_intermediate_size * config.n_shared_experts
@ -322,16 +336,14 @@ class DeepseekV2MoE(nn.Module):
scoring_func=getattr(config, "scoring_func", "softmax"), scoring_func=getattr(config, "scoring_func", "softmax"),
# we do scaling outside, set factor to 1.0 to avoid double mul # we do scaling outside, set factor to 1.0 to avoid double mul
# aiter applies routed_scaling_factor internally # aiter applies routed_scaling_factor internally
routed_scaling_factor=1.0 routed_scaling_factor=self.routed_scaling_factor
if not self.is_rocm_aiter_moe_enabled if self.is_rocm_aiter_moe_enabled or self.mix_placement
else self.routed_scaling_factor, else 1.0,
e_score_correction_bias=self.gate.e_score_correction_bias, e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb, enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts, num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel, is_sequence_parallel=self.is_sequence_parallel,
n_shared_experts=config.n_shared_experts n_shared_experts=config.n_shared_experts if self.mix_placement else None,
if self.is_fusion_moe_shared_experts_enabled
else None,
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@ -364,7 +376,7 @@ class DeepseekV2MoE(nn.Module):
# Fix FP16 overflow # Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details. # See DeepseekV2DecoderLayer for more details.
if hidden_states.dtype != torch.float16: if hidden_states.dtype != torch.float16:
if not self.is_rocm_aiter_moe_enabled: if not (self.is_rocm_aiter_moe_enabled or self.mix_placement):
final_hidden_states *= self.routed_scaling_factor final_hidden_states *= self.routed_scaling_factor
elif self.shared_experts is not None: elif self.shared_experts is not None:
assert shared_output is not None assert shared_output is not None
@ -1393,6 +1405,7 @@ class DeepseekV2ForCausalLM(
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.additional_config = vllm_config.additional_config
self.quant_config = quant_config self.quant_config = quant_config
qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0) qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
@ -1494,9 +1507,16 @@ class DeepseekV2ForCausalLM(
) )
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
if isinstance(self.additional_config, dict):
mix_placement = self.additional_config.get("mix_placement", False)
else:
mix_placement = getattr(self.additional_config, "mix_placement", False)
rocm_aiter_moe_shared_expert_enabled = ( rocm_aiter_moe_shared_expert_enabled = (
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
) )
if rocm_aiter_moe_shared_expert_enabled:
mix_placement = True
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "gate_proj", 0),
@ -1523,11 +1543,7 @@ class DeepseekV2ForCausalLM(
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts num_experts=self.config.n_routed_experts
+ ( + (self.config.n_shared_experts if mix_placement else 0),
self.config.n_shared_experts
if rocm_aiter_moe_shared_expert_enabled
else 0
),
num_redundant_experts=self.num_redundant_experts, num_redundant_experts=self.num_redundant_experts,
) )
@ -1541,8 +1557,8 @@ class DeepseekV2ForCausalLM(
if spec_layer is not None: if spec_layer is not None:
continue # skip spec decode layers for main model continue # skip spec decode layers for main model
is_fusion_moe_shared_experts_layer = ( is_fusion_moe_shared_experts_layer = mix_placement and (
rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name) "mlp.shared_experts" in name
) )
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping: