mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-18 10:27:10 +08:00
[Model] Removes redundant all-reduce operation in Qwen3MoeSparseMoeBlock (#23169)
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
parent
1298c67795
commit
4f510bc2a1
@ -139,7 +139,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
reduce_results=False,
|
||||
reduce_results=True,
|
||||
renormalize=config.norm_topk_prob,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.experts",
|
||||
@ -163,10 +163,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
final_hidden_states = self.experts(hidden_states=hidden_states,
|
||||
router_logits=router_logits)
|
||||
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
|
||||
final_hidden_states)
|
||||
|
||||
return final_hidden_states.view(orig_shape)
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user