diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py
index ace7e333e2137..d002d1838c8ea 100644
--- a/vllm/model_executor/models/ernie45_vl_moe.py
+++ b/vllm/model_executor/models/ernie45_vl_moe.py
@@ -341,7 +341,10 @@ class Ernie4_5_VLMoeMoE(nn.Module):
             # text and vision modals input
             visual_token_mask = visual_token_mask.repeat(1, self.hidden_size).bool()
             text_token_mask = ~visual_token_mask
-            final_hidden_states = torch.zeros_like(hidden_states)
+            final_experts_hidden_states = torch.zeros_like(hidden_states)
+            final_shared_output = (
+                torch.zeros_like(hidden_states) if self.has_shared_experts else None
+            )
 
             text_hidden_states = hidden_states[text_token_mask].reshape(
                 -1, self.hidden_size
@@ -353,16 +356,26 @@ class Ernie4_5_VLMoeMoE(nn.Module):
             text_router_logits, _ = self.text_experts_gate(
                 text_hidden_states.to(dtype=torch.float32)
             )
-            final_hidden_states[text_token_mask] = self.text_experts(
+            text_shared_output, text_experts_output = self.text_experts(
                 hidden_states=text_hidden_states, router_logits=text_router_logits
-            ).flatten()
+            )
+            final_experts_hidden_states[text_token_mask] = text_experts_output.flatten()
+            if self.has_shared_experts:
+                final_shared_output[text_token_mask] = text_shared_output.flatten()
 
             vision_router_logits, _ = self.vision_experts_gate(
                 vision_hidden_states.to(dtype=torch.float32)
             )
-            final_hidden_states[visual_token_mask] = self.vision_experts(
+            vision_shared_output, vision_experts_output = self.vision_experts(
                 hidden_states=vision_hidden_states, router_logits=vision_router_logits
-            ).flatten()
+            )
+            final_experts_hidden_states[visual_token_mask] = (
+                vision_experts_output.flatten()
+            )
+            if self.has_shared_experts:
+                final_shared_output[visual_token_mask] = vision_shared_output.flatten()
+
+            final_hidden_states = (final_shared_output, final_experts_hidden_states)
         else:
             # only text modal input
             text_router_logits, _ = self.text_experts_gate(
@@ -374,7 +387,11 @@ class Ernie4_5_VLMoeMoE(nn.Module):
             )
 
         if self.has_shared_experts:
+            # model with shared experts: add the shared-expert output
             final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
+        else:
+            # model without shared experts: keep only the routed-expert output
+            final_hidden_states = final_hidden_states[1]
 
         if self.tp_size > 1:
             final_hidden_states = (
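
For reference, here is a minimal, self-contained sketch of the control flow this diff introduces: text and vision tokens are scattered to separate expert groups via boolean masks, and the routed-expert and shared-expert contributions are accumulated in separate buffers before being summed at the end. It assumes each expert module returns a `(shared_output, experts_output)` tuple, as the new unpacking above implies; the router gates are folded into the expert callables for brevity, and every name here (`moe_forward`, `dummy_experts`, etc.) is an illustrative stand-in, not the model's API.

```python
# Sketch only: simplified stand-in for the masked text/vision MoE dispatch
# above, not vLLM's implementation.
import torch


def moe_forward(
    hidden_states: torch.Tensor,      # [num_tokens, hidden_size]
    visual_token_mask: torch.Tensor,  # [num_tokens, 1], 1 marks vision tokens
    text_experts,                     # callable returning (shared, routed)
    vision_experts,                   # callable returning (shared, routed)
    has_shared_experts: bool,
) -> torch.Tensor:
    hidden_size = hidden_states.shape[-1]
    visual_mask = visual_token_mask.repeat(1, hidden_size).bool()
    text_mask = ~visual_mask

    # Separate accumulators so routed-expert and shared-expert outputs can be
    # scattered back to their token positions independently.
    experts_out = torch.zeros_like(hidden_states)
    shared_out = torch.zeros_like(hidden_states) if has_shared_experts else None

    text_tokens = hidden_states[text_mask].reshape(-1, hidden_size)
    vision_tokens = hidden_states[visual_mask].reshape(-1, hidden_size)

    # Each modality goes through its own expert group; gating is assumed to
    # happen inside the callable.
    text_shared, text_routed = text_experts(text_tokens)
    experts_out[text_mask] = text_routed.flatten()
    if has_shared_experts:
        shared_out[text_mask] = text_shared.flatten()

    vision_shared, vision_routed = vision_experts(vision_tokens)
    experts_out[visual_mask] = vision_routed.flatten()
    if has_shared_experts:
        shared_out[visual_mask] = vision_shared.flatten()

    # Collapse the (shared, routed) pair exactly like the final hunk above.
    return experts_out + shared_out if has_shared_experts else experts_out


if __name__ == "__main__":
    torch.manual_seed(0)
    h = torch.randn(4, 8)                     # 4 tokens, hidden_size 8
    vis = torch.tensor([[0], [1], [0], [1]])  # tokens 1 and 3 are vision
    dummy_experts = lambda x: (0.5 * x, 2.0 * x)  # hypothetical (shared, routed)
    out = moe_forward(h, vis, dummy_experts, dummy_experts, has_shared_experts=True)
    assert out.shape == h.shape
```

Keeping the two accumulators separate, rather than writing into a single `final_hidden_states` buffer as before, is what lets the shared-expert path stay optional: when `has_shared_experts` is false, the shared buffer is never allocated and the routed output is returned unchanged.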