mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 10:34:27 +08:00
[Model][Bugfix] fix ernie45 vl run failed from shared experts optimization (#26885)
Signed-off-by: wangyafeng <wangyafeng@baidu.com>
This commit is contained in:
parent
d2740fafbf
commit
e51928793e
@ -341,7 +341,10 @@ class Ernie4_5_VLMoeMoE(nn.Module):
|
||||
# text and vision modals input
|
||||
visual_token_mask = visual_token_mask.repeat(1, self.hidden_size).bool()
|
||||
text_token_mask = ~visual_token_mask
|
||||
final_hidden_states = torch.zeros_like(hidden_states)
|
||||
final_experts_hidden_states = torch.zeros_like(hidden_states)
|
||||
final_shared_ouput = (
|
||||
torch.zeros_like(hidden_states) if self.has_shared_experts else None
|
||||
)
|
||||
|
||||
text_hidden_states = hidden_states[text_token_mask].reshape(
|
||||
-1, self.hidden_size
|
||||
@ -353,16 +356,26 @@ class Ernie4_5_VLMoeMoE(nn.Module):
|
||||
text_router_logits, _ = self.text_experts_gate(
|
||||
text_hidden_states.to(dtype=torch.float32)
|
||||
)
|
||||
final_hidden_states[text_token_mask] = self.text_experts(
|
||||
text_shared_ouput, text_experts_output = self.text_experts(
|
||||
hidden_states=text_hidden_states, router_logits=text_router_logits
|
||||
).flatten()
|
||||
)
|
||||
final_experts_hidden_states[text_token_mask] = text_experts_output.flatten()
|
||||
if self.has_shared_experts:
|
||||
final_shared_ouput[text_token_mask] = text_shared_ouput.flatten()
|
||||
|
||||
vision_router_logits, _ = self.vision_experts_gate(
|
||||
vision_hidden_states.to(dtype=torch.float32)
|
||||
)
|
||||
final_hidden_states[visual_token_mask] = self.vision_experts(
|
||||
vision_shared_ouput, vision_experts_output = self.vision_experts(
|
||||
hidden_states=vision_hidden_states, router_logits=vision_router_logits
|
||||
).flatten()
|
||||
)
|
||||
final_experts_hidden_states[visual_token_mask] = (
|
||||
vision_experts_output.flatten()
|
||||
)
|
||||
if self.has_shared_experts:
|
||||
final_shared_ouput[visual_token_mask] = vision_shared_ouput.flatten()
|
||||
|
||||
final_hidden_states = (final_shared_ouput, final_experts_hidden_states)
|
||||
else:
|
||||
# only text modal input
|
||||
text_router_logits, _ = self.text_experts_gate(
|
||||
@ -374,7 +387,11 @@ class Ernie4_5_VLMoeMoE(nn.Module):
|
||||
)
|
||||
|
||||
if self.has_shared_experts:
|
||||
# for shared_experts model
|
||||
final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
|
||||
else:
|
||||
# for not shared_experts model
|
||||
final_hidden_states = final_hidden_states[1]
|
||||
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = (
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user