[Model] Apply SharedFusedMoE to glm4_moe. (#24849)
Signed-off-by: whx-sjtu <2952154980@qq.com>
parent 4a9375fe9d
commit c15309a730
vllm/model_executor/models/glm4_moe.py

@@ -46,6 +46,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
@@ -146,25 +147,6 @@ class Glm4MoE(nn.Module):
         self.physical_expert_end = (self.physical_expert_start +
                                     self.n_local_physical_experts)
 
-        self.experts = FusedMoE(
-            num_experts=config.n_routed_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
-            renormalize=config.norm_topk_prob,
-            quant_config=quant_config,
-            use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            topk_group=config.topk_group,
-            prefix=f"{prefix}.experts",
-            scoring_func="sigmoid",
-            # we do scaling outside, set factor to 1.0 to avoid double mul
-            routed_scaling_factor=1.0,
-            e_score_correction_bias=self.gate.e_score_correction_bias,
-            enable_eplb=self.enable_eplb,
-            num_redundant_experts=self.n_redundant_experts)
-
         if config.n_shared_experts is not None:
             intermediate_size = (config.moe_intermediate_size *
                                  config.n_shared_experts)
@@ -173,25 +155,68 @@ class Glm4MoE(nn.Module):
                 intermediate_size=intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
-                reduce_results=self.experts.must_reduce_shared_expert_outputs(
-                ),
+                reduce_results=False,
                 prefix=f"{prefix}.shared_experts",
             )
+            self.experts = SharedFusedMoE(
+                shared_experts=self.shared_experts,
+                num_experts=config.n_routed_experts,
+                top_k=config.num_experts_per_tok,
+                hidden_size=config.hidden_size,
+                intermediate_size=config.moe_intermediate_size,
+                reduce_results=False,
+                renormalize=config.norm_topk_prob,
+                quant_config=quant_config,
+                use_grouped_topk=True,
+                num_expert_group=config.n_group,
+                topk_group=config.topk_group,
+                prefix=f"{prefix}.experts",
+                scoring_func="sigmoid",
+                # we do scaling outside, set factor to 1.0 to avoid double mul
+                routed_scaling_factor=1.0,
+                e_score_correction_bias=self.gate.e_score_correction_bias,
+                enable_eplb=self.enable_eplb,
+                num_redundant_experts=self.n_redundant_experts,
+            )
+        else:
+            self.experts = FusedMoE(
+                num_experts=config.n_routed_experts,
+                top_k=config.num_experts_per_tok,
+                hidden_size=config.hidden_size,
+                intermediate_size=config.moe_intermediate_size,
+                reduce_results=False,
+                renormalize=config.norm_topk_prob,
+                quant_config=quant_config,
+                use_grouped_topk=True,
+                num_expert_group=config.n_group,
+                topk_group=config.topk_group,
+                prefix=f"{prefix}.experts",
+                scoring_func="sigmoid",
+                # we do scaling outside, set factor to 1.0 to avoid double mul
+                routed_scaling_factor=1.0,
+                e_score_correction_bias=self.gate.e_score_correction_bias,
+                enable_eplb=self.enable_eplb,
+                num_redundant_experts=self.n_redundant_experts)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
-        if self.n_shared_experts is not None:
-            shared_output = self.shared_experts(hidden_states)
-        else:
-            shared_output = None
+        # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states.to(dtype=torch.float32))
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits) * self.routed_scaling_factor
-        if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
+
+        fused_moe_out = self.experts(hidden_states=hidden_states,
+                                     router_logits=router_logits)
+
+        if self.shared_experts is not None:
+            shared_output, final_hidden_states = fused_moe_out
+            assert shared_output is not None
+            final_hidden_states = \
+                final_hidden_states * self.routed_scaling_factor\
+                + shared_output
+        else:
+            final_hidden_states = fused_moe_out * self.routed_scaling_factor
+
         if self.tp_size > 1:
             final_hidden_states = (
                 self.experts.maybe_all_reduce_tensor_model_parallel(
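
Note on the behavioural change: after this patch, when shared experts are present, self.experts is a SharedFusedMoE and returns a tuple of (shared expert output, routed expert output) instead of a single tensor, and the caller scales only the routed part by routed_scaling_factor before adding the shared output. Below is a minimal, self-contained sketch of that calling convention using toy stand-in modules; it is not vLLM's actual SharedFusedMoE/FusedMoE implementation, and the Toy* classes and tensor sizes are invented for illustration.

import torch
import torch.nn as nn


class ToyRoutedExperts(nn.Module):
    """Stand-in for FusedMoE: returns only the routed-expert output."""

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor) -> torch.Tensor:
        # A real FusedMoE would dispatch tokens to top-k experts here.
        return self.proj(hidden_states)


class ToySharedFusedMoE(ToyRoutedExperts):
    """Stand-in for SharedFusedMoE: also runs the shared experts and
    returns (shared_output, routed_output)."""

    def __init__(self, shared_experts: nn.Module, hidden_size: int) -> None:
        super().__init__(hidden_size)
        self.shared_experts = shared_experts

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor):
        shared_output = self.shared_experts(hidden_states)
        routed_output = super().forward(hidden_states, router_logits)
        return shared_output, routed_output


hidden_size = 16
routed_scaling_factor = 1.0  # applied outside the experts module, as in the diff
shared_experts = nn.Linear(hidden_size, hidden_size)  # toy shared experts
experts = ToySharedFusedMoE(shared_experts, hidden_size)

hidden_states = torch.randn(4, hidden_size)
router_logits = torch.randn(4, 8)
fused_moe_out = experts(hidden_states=hidden_states,
                        router_logits=router_logits)

# Same combination logic as the new Glm4MoE.forward above: scale only the
# routed output, then add the shared-expert output.
shared_output, final_hidden_states = fused_moe_out
final_hidden_states = (final_hidden_states * routed_scaling_factor +
                       shared_output)
print(final_hidden_states.shape)  # torch.Size([4, 16])

Folding the shared experts into the fused MoE module also explains why reduce_results on the shared-expert MLP changes from must_reduce_shared_expert_outputs() to False in the diff: presumably the tensor-parallel reduction is now handled once for the combined result via maybe_all_reduce_tensor_model_parallel rather than separately for the shared experts.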