From 63b22e0dbb901b75619aa4bca2dfa1d7a71f439e Mon Sep 17 00:00:00 2001 From: CSWYF3634076 Date: Mon, 27 Oct 2025 11:53:31 +0800 Subject: [PATCH] [Model][Bugfix] fix ernie45 moe 300B SharedFusedMoE output tuple (#27316) Signed-off-by: wangyafeng --- vllm/model_executor/models/ernie45_moe.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/models/ernie45_moe.py b/vllm/model_executor/models/ernie45_moe.py index 607589e68ef3..192ca0585230 100644 --- a/vllm/model_executor/models/ernie45_moe.py +++ b/vllm/model_executor/models/ernie45_moe.py @@ -215,6 +215,8 @@ class Ernie4_5_MoeMoE(nn.Module): if self.has_shared_experts: final_hidden_states = final_hidden_states[0] + final_hidden_states[1] + else: + final_hidden_states = final_hidden_states[1] if self.tp_size > 1: final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(