[Model][Bugfix] fix ernie45 vl run failed from shared experts optimization (#26885)

Signed-off-by: wangyafeng <wangyafeng@baidu.com>
CSWYF3634076 2025-10-16 18:37:35 +08:00 committed by GitHub
parent d2740fafbf
commit e51928793e
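Background: after the shared-experts optimization, the fused expert modules (`self.text_experts` / `self.vision_experts`) return a `(shared_output, experts_output)` tuple instead of a single tensor, so the old mixed text/vision path, which called `.flatten()` directly on the return value, fails at runtime. The diff below unpacks the tuple and keeps separate accumulation buffers for the shared and routed outputs. A minimal sketch of the failure mode and the fix, assuming a toy `experts` function in place of the real fused MoE (all names here are illustrative, not the vLLM API):

import torch

# Toy stand-in for the post-optimization expert module: like the patched
# model, it returns a (shared_output, experts_output) tuple, not a tensor.
def experts(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    return x * 0.5, x * 2.0  # illustrative math only

hidden_size = 8
hidden_states = torch.randn(4, hidden_size)
token_mask = torch.tensor([True, False, True, False])
# Per-element mask, mirroring visual_token_mask.repeat(1, hidden_size)
elem_mask = token_mask.unsqueeze(1).repeat(1, hidden_size)

selected = hidden_states[elem_mask].reshape(-1, hidden_size)

# Old pattern -- crashes now, since experts(...) is a tuple:
#   final_hidden_states[elem_mask] = experts(selected).flatten()
#   AttributeError: 'tuple' object has no attribute 'flatten'

# Fixed pattern: unpack, then scatter each part into its own buffer.
final_experts = torch.zeros_like(hidden_states)
final_shared = torch.zeros_like(hidden_states)
shared_out, experts_out = experts(selected)
final_experts[elem_mask] = experts_out.flatten()
final_shared[elem_mask] = shared_out.flatten()
# Shared-experts case: combine the two buffers at the end.
final = final_shared + final_experts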


@@ -341,7 +341,10 @@ class Ernie4_5_VLMoeMoE(nn.Module):
             # text and vision modals input
             visual_token_mask = visual_token_mask.repeat(1, self.hidden_size).bool()
             text_token_mask = ~visual_token_mask
-            final_hidden_states = torch.zeros_like(hidden_states)
+            final_experts_hidden_states = torch.zeros_like(hidden_states)
+            final_shared_ouput = (
+                torch.zeros_like(hidden_states) if self.has_shared_experts else None
+            )
             text_hidden_states = hidden_states[text_token_mask].reshape(
                 -1, self.hidden_size
@@ -353,16 +356,26 @@ class Ernie4_5_VLMoeMoE(nn.Module):
             text_router_logits, _ = self.text_experts_gate(
                 text_hidden_states.to(dtype=torch.float32)
             )
-            final_hidden_states[text_token_mask] = self.text_experts(
+            text_shared_ouput, text_experts_output = self.text_experts(
                 hidden_states=text_hidden_states, router_logits=text_router_logits
-            ).flatten()
+            )
+            final_experts_hidden_states[text_token_mask] = text_experts_output.flatten()
+            if self.has_shared_experts:
+                final_shared_ouput[text_token_mask] = text_shared_ouput.flatten()

             vision_router_logits, _ = self.vision_experts_gate(
                 vision_hidden_states.to(dtype=torch.float32)
             )
-            final_hidden_states[visual_token_mask] = self.vision_experts(
+            vision_shared_ouput, vision_experts_output = self.vision_experts(
                 hidden_states=vision_hidden_states, router_logits=vision_router_logits
-            ).flatten()
+            )
+            final_experts_hidden_states[visual_token_mask] = (
+                vision_experts_output.flatten()
+            )
+            if self.has_shared_experts:
+                final_shared_ouput[visual_token_mask] = vision_shared_ouput.flatten()
+            final_hidden_states = (final_shared_ouput, final_experts_hidden_states)
         else:
             # only text modal input
             text_router_logits, _ = self.text_experts_gate(
@@ -374,7 +387,11 @@ class Ernie4_5_VLMoeMoE(nn.Module):
             )

         if self.has_shared_experts:
+            # for shared_experts model
             final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
+        else:
+            # for not shared_experts model
+            final_hidden_states = final_hidden_states[1]

         if self.tp_size > 1:
             final_hidden_states = (
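After this change, both branches leave `final_hidden_states` as a `(shared_output, experts_output)` tuple, and the tail of the forward pass collapses it into a single tensor: the sum of both parts when shared experts are present, or just the routed-expert output otherwise. The `self.tp_size > 1` guard that follows wraps the usual tensor-parallel reduction; the reduce call itself is truncated in this hunk.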