mix_placement
Signed-off-by: Che Ruan <cr623@ic.ac.uk>

format fix
Signed-off-by: Che Ruan <cr623@ic.ac.uk>

format fix
Signed-off-by: Che Ruan <cr623@ic.ac.uk>

Update deepseek_v2.py
Signed-off-by: Mercykid-bash <ruanche0218@gmail.com>

format fix
Signed-off-by: Che Ruan <cr623@ic.ac.uk>

Update deepseek_v2.py
Signed-off-by: Mercykid-bash <ruanche0218@gmail.com>

format fix
Signed-off-by: Che Ruan <cr623@ic.ac.uk>
This commit is contained in:
parent ddfac7034e
commit 25c203ff87
@@ -34,7 +34,7 @@ from dataclasses import dataclass
 import torch
 from torch.distributed import ProcessGroup, all_reduce
 
-from vllm.config import ModelConfig, ParallelConfig
+from vllm.config import ModelConfig, ParallelConfig, get_current_vllm_config
 from vllm.distributed.parallel_state import (
     get_ep_group,
     get_node_count,
@@ -269,6 +269,8 @@ class EplbState:
     def build_initial_global_physical_to_logical_map(
         num_routed_experts: int,
         num_redundant_experts: int,
+        num_shared_experts: int = 0,
+        mix_placement: bool = False,
     ) -> Sequence[int]:
         """
         Build an initial expert arrangement using the following structure:
@@ -279,10 +281,25 @@ class EplbState:
         where each integer is the index of the logical expert
         that the corresponding physical expert maps to.
         """
-        global_physical_to_logical_map = list(range(num_routed_experts))
-        global_physical_to_logical_map += [
-            i % num_routed_experts for i in range(num_redundant_experts)
-        ]
+        ep_size = get_ep_group().world_size
+        num_physical_experts = num_routed_experts + num_redundant_experts
+        if mix_placement:
+            num_base_experts = num_physical_experts // ep_size
+            global_physical_to_logical_map = list()
+            for ep_rank in range(ep_size):
+                start_idx = ep_rank * num_base_experts
+                end_idx = (ep_rank + 1) * num_base_experts
+                global_physical_to_logical_map += [
+                    i % num_routed_experts for i in range(start_idx, end_idx)
+                ]
+                global_physical_to_logical_map += [
+                    num_routed_experts + i for i in range(num_shared_experts)
+                ]
+        else:
+            global_physical_to_logical_map = list(range(num_routed_experts))
+            global_physical_to_logical_map += [
+                i % num_routed_experts for i in range(num_redundant_experts)
+            ]
         return global_physical_to_logical_map
 
     def validate_ep_configuration(self, new_model: MixtureOfExperts):
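For orientation, the mix_placement branch above gives each EP rank a contiguous block of routed experts followed by one physical slot per shared expert. A minimal standalone sketch (toy sizes, not part of the patch) that mirrors the same loop:

# Standalone sketch of the mapping built above, with toy sizes (not vLLM code).
def toy_initial_map(num_routed, num_redundant, num_shared, ep_size, mix_placement):
    num_physical = num_routed + num_redundant
    if mix_placement:
        num_base = num_physical // ep_size
        mapping = []
        for ep_rank in range(ep_size):
            # contiguous block of routed experts owned by this rank ...
            mapping += [
                i % num_routed
                for i in range(ep_rank * num_base, (ep_rank + 1) * num_base)
            ]
            # ... followed by one slot per shared expert, replicated on every rank
            mapping += [num_routed + i for i in range(num_shared)]
        return mapping
    # non-mixed layout: routed experts first, then redundant copies
    mapping = list(range(num_routed))
    mapping += [i % num_routed for i in range(num_redundant)]
    return mapping

# 4 routed experts, 0 redundant, 1 shared expert, EP size 2
# -> [0, 1, 4, 2, 3, 4]: logical expert 4 (the shared expert) appears once per rank.
print(toy_initial_map(4, 0, 1, ep_size=2, mix_placement=True))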
@@ -334,11 +351,18 @@ class EplbState:
         """
         self.validate_ep_configuration(model)
         self.is_async = self.parallel_config.eplb_config.use_async
+        additional_config = get_current_vllm_config().additional_config
+        if isinstance(additional_config, dict):
+            mix_placement = additional_config.get("mix_placement", False)
+        else:
+            mix_placement = getattr(additional_config, "mix_placement", False)
 
         physical_to_logical_map_list = (
             EplbState.build_initial_global_physical_to_logical_map(
                 model.num_routed_experts,
                 model.num_redundant_experts,
+                model.num_shared_experts,
+                mix_placement,
             )
         )
         physical_to_logical_map = torch.tensor(
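This hunk, like the FusedMoE and DeepseekV2 changes below, reads the flag from additional_config and accepts either a plain dict or an attribute-style config object. A hedged usage sketch (the model name, and the assumption that additional_config is forwarded through the engine arguments as in mainline vLLM, are illustrative and not taken from this patch):

# Hedged usage sketch: enabling mix_placement via additional_config.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",   # example model only
    enable_expert_parallel=True,
    additional_config={"mix_placement": True},
)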
@@ -102,6 +102,7 @@ def determine_expert_map(
     global_num_experts: int,
     expert_placement_strategy: ExpertPlacementStrategy = "linear",
     num_fused_shared_experts: int = 0,
+    mix_placement: bool = False,
     return_expert_mask: bool = False,
 ) -> tuple[int, torch.Tensor | None, torch.Tensor | None]:
     """
@@ -138,6 +139,8 @@ def determine_expert_map(
 
     # Distribute experts as evenly as possible to each rank.
     base_experts = global_num_experts // ep_size
+    if mix_placement:
+        base_experts += num_fused_shared_experts
     remainder = global_num_experts % ep_size
     local_num_experts = base_experts + 1 if ep_rank < remainder else base_experts
 
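The two added lines only change the per-rank slot count; a toy calculation (numbers are illustrative):

# Toy per-rank slot arithmetic for determine_expert_map (illustrative numbers).
global_num_experts, ep_size, num_fused_shared_experts = 64, 8, 1
mix_placement = True
ep_rank = 0

base_experts = global_num_experts // ep_size       # 8 routed experts per rank
if mix_placement:
    base_experts += num_fused_shared_experts       # + 1 local shared-expert slot -> 9
remainder = global_num_experts % ep_size           # 0: no rank takes an extra expert
local_num_experts = base_experts + 1 if ep_rank < remainder else base_experts
print(base_experts, remainder, local_num_experts)  # -> 9 0 9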
@@ -433,7 +436,11 @@ class FusedMoE(CustomOp):
         self.expert_placement_strategy: ExpertPlacementStrategy = (
             vllm_config.parallel_config.expert_placement_strategy
         )
-
+        additional_config = self.vllm_config.additional_config
+        if isinstance(additional_config, dict):
+            self.mix_placement = additional_config.get("mix_placement", False)
+        else:
+            self.mix_placement = getattr(additional_config, "mix_placement", False)
         # ROCm aiter shared experts fusion
         self.rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
         self.aiter_fmoe_shared_expert_enabled = (
@@ -442,7 +449,8 @@ class FusedMoE(CustomOp):
 
         self.num_fused_shared_experts = (
             n_shared_experts
-            if n_shared_experts is not None and self.aiter_fmoe_shared_expert_enabled
+            if n_shared_experts is not None
+            and (self.aiter_fmoe_shared_expert_enabled or self.mix_placement)
             else 0
         )
         if (
@@ -507,10 +515,10 @@ class FusedMoE(CustomOp):
         )
 
         self.top_k = top_k
-
-        self._init_aiter_shared_experts_topK_buffer(
-            vllm_config=vllm_config, dp_size=dp_size_
-        )
+        if self.aiter_fmoe_shared_expert_enabled:
+            self._init_aiter_shared_experts_topK_buffer(
+                vllm_config=vllm_config, dp_size=dp_size_
+            )
         if self.use_ep and self.rocm_aiter_fmoe_enabled:
             assert self.expert_mask is None or torch.all(
                 (expert_mask == 0) | (expert_mask == 1)
@@ -1637,6 +1645,26 @@ class FusedMoE(CustomOp):
             renormalize=self.renormalize,
         )
 
+        if self.mix_placement:
+            if self.routed_scaling_factor != 1.0:
+                topk_weights *= self.routed_scaling_factor
+            shared_expert_routing_factor = 1.0
+            batch_size = topk_ids.shape[0]
+            shared_expert_ids = torch.arange(
+                self.logical_num_experts,
+                self.logical_num_experts + self.num_fused_shared_experts,
+                dtype=topk_ids.dtype,
+                device=topk_ids.device,
+            ).repeat(batch_size, 1)
+            shared_expert_weights = torch.full(
+                (batch_size, self.num_fused_shared_experts),
+                shared_expert_routing_factor,
+                dtype=topk_weights.dtype,
+                device=topk_ids.device,
+            )
+            topk_ids = torch.cat([topk_ids, shared_expert_ids], dim=1)
+            topk_weights = torch.cat([topk_weights, shared_expert_weights], dim=1)
+
         if self.enable_eplb:
             topk_ids = eplb_map_to_physical_and_record(
                 topk_ids=topk_ids,
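Concretely, the branch above widens every token's top-k with fixed slots that always route to the shared experts at weight 1.0, so the fused kernel computes them alongside the routed experts. A runnable toy illustration (shapes and ids are made up):

import torch

# Toy illustration (not vLLM code) of appending shared-expert slots to the top-k.
logical_num_experts, num_fused_shared_experts = 8, 1
topk_ids = torch.tensor([[3, 5], [0, 7]])                    # routed choices per token
topk_weights = torch.tensor([[0.6, 0.4], [0.7, 0.3]])

shared_ids = torch.arange(
    logical_num_experts,
    logical_num_experts + num_fused_shared_experts,
    dtype=topk_ids.dtype,
).repeat(topk_ids.shape[0], 1)                               # [[8], [8]]
shared_weights = torch.ones(
    topk_ids.shape[0], num_fused_shared_experts, dtype=topk_weights.dtype
)

topk_ids = torch.cat([topk_ids, shared_ids], dim=1)          # [[3, 5, 8], [0, 7, 8]]
topk_weights = torch.cat([topk_weights, shared_weights], dim=1)
print(topk_ids.tolist(), topk_weights.tolist())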
@@ -276,21 +276,35 @@ class DeepseekV2MoE(nn.Module):
         eplb_config = parallel_config.eplb_config
         self.enable_eplb = parallel_config.enable_eplb
 
-        self.n_redundant_experts = eplb_config.num_redundant_experts
-        self.n_logical_experts = self.n_routed_experts
-        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
-        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+        additional_config = get_current_vllm_config().additional_config
+        if isinstance(additional_config, dict):
+            self.mix_placement = additional_config.get("mix_placement", False)
+        else:
+            self.mix_placement = getattr(additional_config, "mix_placement", False)
+
+        self.n_redundant_experts = eplb_config.num_redundant_experts
+        num_physical_routed_experts = self.n_routed_experts + self.n_redundant_experts
+        if self.mix_placement:
+            self.n_logical_experts = self.n_routed_experts + self.n_shared_experts
+            self.n_physical_experts = (
+                num_physical_routed_experts + self.ep_size * self.n_shared_experts
+            )
+            self.n_local_physical_experts = (
+                num_physical_routed_experts // self.ep_size + self.n_shared_experts
+            )
+        else:
+            self.n_logical_experts = self.n_routed_experts
+            self.n_physical_experts = num_physical_routed_experts
+            self.n_local_physical_experts = self.n_physical_experts // self.ep_size
         self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
         self.physical_expert_end = (
             self.physical_expert_start + self.n_local_physical_experts
         )
 
         self.is_rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
         self.is_fusion_moe_shared_experts_enabled = (
             rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
         )
-        if config.n_shared_experts is None or self.is_fusion_moe_shared_experts_enabled:
+        if rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
+            self.mix_placement = True
+        if config.n_shared_experts is None or self.mix_placement:
             self.shared_experts = None
         else:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
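The bookkeeping above can be sanity-checked with assumed sizes (not any specific model's config):

# Toy expert-count check for the mix_placement branch (assumed numbers):
# 64 routed experts, 0 redundant, 1 shared expert, EP size 8.
n_routed, n_redundant, n_shared, ep_size = 64, 0, 1, 8

num_physical_routed = n_routed + n_redundant                   # 64
n_logical = n_routed + n_shared                                # 65: shared expert becomes a logical expert
n_physical = num_physical_routed + ep_size * n_shared          # 72: one shared copy per rank
n_local_physical = num_physical_routed // ep_size + n_shared   # 9 physical slots per rank
assert n_local_physical * ep_size == n_physical
print(n_logical, n_physical, n_local_physical)                 # -> 65 72 9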
@@ -322,16 +336,14 @@ class DeepseekV2MoE(nn.Module):
             scoring_func=getattr(config, "scoring_func", "softmax"),
             # we do scaling outside, set factor to 1.0 to avoid double mul
             # aiter applies routed_scaling_factor internally
-            routed_scaling_factor=1.0
-            if not self.is_rocm_aiter_moe_enabled
-            else self.routed_scaling_factor,
+            routed_scaling_factor=self.routed_scaling_factor
+            if self.is_rocm_aiter_moe_enabled or self.mix_placement
+            else 1.0,
             e_score_correction_bias=self.gate.e_score_correction_bias,
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.n_redundant_experts,
             is_sequence_parallel=self.is_sequence_parallel,
-            n_shared_experts=config.n_shared_experts
-            if self.is_fusion_moe_shared_experts_enabled
-            else None,
+            n_shared_experts=config.n_shared_experts if self.mix_placement else None,
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -364,7 +376,7 @@ class DeepseekV2MoE(nn.Module):
         # Fix FP16 overflow
         # See DeepseekV2DecoderLayer for more details.
         if hidden_states.dtype != torch.float16:
-            if not self.is_rocm_aiter_moe_enabled:
+            if not (self.is_rocm_aiter_moe_enabled or self.mix_placement):
                 final_hidden_states *= self.routed_scaling_factor
         elif self.shared_experts is not None:
             assert shared_output is not None
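The scaling change in the last two hunks follows from mix_placement: routed_scaling_factor must multiply only the routed experts' weights, not the shared expert's implicit weight of 1.0, so it moves from the summed MoE output into the router weights. A toy check (values are illustrative):

import torch

# Why scaling moves into the router when mix_placement is on (toy values).
routed_scaling_factor = 2.5
routed_weights = torch.tensor([0.6, 0.4])
shared_weight = torch.tensor([1.0])

# mix_placement: scale routed weights first, then append the unscaled shared weight.
combined = torch.cat([routed_weights * routed_scaling_factor, shared_weight])
print(combined)  # tensor([1.5000, 1.0000, 1.0000])

# Scaling the combined result afterwards would wrongly scale the shared expert too.
wrong = torch.cat([routed_weights, shared_weight]) * routed_scaling_factor
print(wrong)     # tensor([1.5000, 1.0000, 2.5000])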
@@ -1393,6 +1405,7 @@ class DeepseekV2ForCausalLM(
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
         self.config = config
+        self.additional_config = vllm_config.additional_config
         self.quant_config = quant_config
 
         qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
@@ -1494,9 +1507,16 @@ class DeepseekV2ForCausalLM(
         )
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        if isinstance(self.additional_config, dict):
+            mix_placement = self.additional_config.get("mix_placement", False)
+        else:
+            mix_placement = getattr(self.additional_config, "mix_placement", False)
+
         rocm_aiter_moe_shared_expert_enabled = (
             rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
         )
+        if rocm_aiter_moe_shared_expert_enabled:
+            mix_placement = True
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
@@ -1523,11 +1543,7 @@ class DeepseekV2ForCausalLM(
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
             num_experts=self.config.n_routed_experts
-            + (
-                self.config.n_shared_experts
-                if rocm_aiter_moe_shared_expert_enabled
-                else 0
-            ),
+            + (self.config.n_shared_experts if mix_placement else 0),
             num_redundant_experts=self.num_redundant_experts,
         )
 
@@ -1541,8 +1557,8 @@ class DeepseekV2ForCausalLM(
             if spec_layer is not None:
                 continue  # skip spec decode layers for main model
 
-            is_fusion_moe_shared_experts_layer = (
-                rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
+            is_fusion_moe_shared_experts_layer = mix_placement and (
+                "mlp.shared_experts" in name
             )
 
             for param_name, weight_name, shard_id in stacked_params_mapping:
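With mix_placement on, load_weights detects shared-expert checkpoint tensors purely by name so they can be folded into the fused expert list. A small illustration (the weight names follow the usual DeepSeek checkpoint layout and are examples, not taken from the patch):

# Name-based detection of shared-expert weights (example names only).
mix_placement = True
names = [
    "model.layers.3.mlp.shared_experts.gate_proj.weight",
    "model.layers.3.mlp.experts.17.down_proj.weight",
    "model.layers.3.self_attn.q_proj.weight",
]
for name in names:
    is_fusion_moe_shared_experts_layer = mix_placement and ("mlp.shared_experts" in name)
    print(name, is_fusion_moe_shared_experts_layer)  # True only for the shared_experts entry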