[Bugfix][Model] Fix ernie45 MoE gate & bias dtype to float32 (#25936)

Signed-off-by: wangyafeng <wangyafeng@baidu.com>
Author: CSWYF3634076
Date: 2025-09-30 19:11:21 +08:00 (committed by GitHub)
Parent: 1ad3aca682
Commit: ef6e0e7132
2 changed files with 13 additions and 7 deletions
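
In short: the router gate in Ernie4_5_MoeMoE and the text/vision expert gates in Ernie4_5_VLMoeMoE are now constructed with params_dtype=torch.float32, the e_score_correction_bias parameters are allocated in float32, and every gate call upcasts its hidden-state input to float32, so expert routing runs at full precision even when the rest of the model computes in bfloat16/float16.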


@@ -120,11 +120,12 @@ class Ernie4_5_MoeMoE(nn.Module):
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.moe_num_experts,
                                      bias=False,
+                                     params_dtype=torch.float32,
                                      quant_config=None,
                                      prefix=f"{prefix}.gate")

         self.gate.e_score_correction_bias = nn.Parameter(
-            torch.empty(config.moe_num_experts))
+            torch.empty(config.moe_num_experts, dtype=torch.float32))

         self.experts = FusedMoE(
             num_experts=config.moe_num_experts,
@@ -157,7 +158,7 @@ class Ernie4_5_MoeMoE(nn.Module):
         if self.has_shared_experts:
             shared_output = self.shared_experts(hidden_states)

-        router_logits, _ = self.gate(hidden_states)
+        router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
         final_hidden_states = self.experts(hidden_states=hidden_states,
                                            router_logits=router_logits)

View File
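For context, a minimal sketch of the routing pattern this change enforces, with a plain torch.nn.Linear standing in for vLLM's ReplicatedLinear and all sizes chosen only for illustration: the gate weights and the score-correction bias live in float32, and the bfloat16 hidden states are upcast before router logits are computed.

# Illustrative sketch only; nn.Linear stands in for ReplicatedLinear and
# hidden_size / num_experts / top_k are hypothetical.
import torch
import torch.nn as nn

hidden_size, num_experts, top_k = 1024, 64, 6

# Gate weights and the correction bias are kept in float32 regardless of
# the model's compute dtype.
gate = nn.Linear(hidden_size, num_experts, bias=False, dtype=torch.float32)
e_score_correction_bias = nn.Parameter(
    torch.zeros(num_experts, dtype=torch.float32))

hidden_states = torch.randn(8, hidden_size, dtype=torch.bfloat16)

# Upcast the router input so logits, scores, and the bias-adjusted ranking
# are all computed at full precision before expert selection.
router_logits = gate(hidden_states.to(dtype=torch.float32))
scores = router_logits.softmax(dim=-1) + e_score_correction_bias
topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)

Near-tie logits in bfloat16 can round to equal values and flip the top-k expert selection; pinning the gate and its input to float32 avoids that failure mode.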

@@ -199,7 +199,7 @@ class Ernie4_5_VLMoeMoE(nn.Module):
             assert config.moe_num_experts[0] == config.moe_num_experts[1]
             self.e_score_correction_bias = nn.Parameter(
-                torch.empty(2, config.moe_num_experts[0]))
+                torch.empty(2, config.moe_num_experts[0], dtype=torch.float32))

         assert text_moe_layer_start_index <= text_moe_layer_end_index
@@ -209,6 +209,7 @@ class Ernie4_5_VLMoeMoE(nn.Module):
         self.text_experts_gate = ReplicatedLinear(
             config.hidden_size,
             config.moe_num_experts[0],
             bias=False,
+            params_dtype=torch.float32,
             quant_config=quant_config,
             prefix=f"{prefix}.text_experts_gate")
@@ -238,6 +239,7 @@ class Ernie4_5_VLMoeMoE(nn.Module):
         self.vision_experts_gate = ReplicatedLinear(
             config.hidden_size,
             config.moe_num_experts[1],
             bias=False,
+            params_dtype=torch.float32,
             quant_config=quant_config,
             prefix=f"{prefix}.vision_experts_gate")
@@ -288,7 +290,8 @@ class Ernie4_5_VLMoeMoE(nn.Module):
         if visual_token_mask is not None and visual_token_mask.all():
             # only vision modal input
-            router_logits, _ = self.vision_experts_gate(hidden_states)
+            router_logits, _ = self.vision_experts_gate(
+                hidden_states.to(dtype=torch.float32))

             final_hidden_states = self.vision_experts(
                 hidden_states=hidden_states, router_logits=router_logits)
         elif visual_token_mask is not None and visual_token_mask.any():
@@ -303,19 +306,21 @@ class Ernie4_5_VLMoeMoE(nn.Module):
             vision_hidden_states = hidden_states[visual_token_mask].reshape(
                 -1, self.hidden_size)

-            text_router_logits, _ = self.text_experts_gate(text_hidden_states)
+            text_router_logits, _ = self.text_experts_gate(
+                text_hidden_states.to(dtype=torch.float32))
             final_hidden_states[text_token_mask] = self.text_experts(
                 hidden_states=text_hidden_states,
                 router_logits=text_router_logits).flatten()

             vision_router_logits, _ = self.vision_experts_gate(
-                vision_hidden_states)
+                vision_hidden_states.to(dtype=torch.float32))

             final_hidden_states[visual_token_mask] = self.vision_experts(
                 hidden_states=vision_hidden_states,
                 router_logits=vision_router_logits).flatten()
         else:
             # only text modal input
-            text_router_logits, _ = self.text_experts_gate(hidden_states)
+            text_router_logits, _ = self.text_experts_gate(
+                hidden_states.to(dtype=torch.float32))
             final_hidden_states = self.text_experts(
                 hidden_states=hidden_states, router_logits=text_router_logits)
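
And a similarly hypothetical sketch of the mixed text/vision dispatch above: each modality's tokens are routed through their own gate, with the gate input upcast to float32 in every branch (plain Linear gates again stand in for text_experts_gate / vision_experts_gate).

# Illustrative only; gates, shapes, and the mask are hypothetical stand-ins.
import torch
import torch.nn as nn

hidden_size, num_experts = 1024, 64
text_gate = nn.Linear(hidden_size, num_experts, bias=False,
                      dtype=torch.float32)
vision_gate = nn.Linear(hidden_size, num_experts, bias=False,
                        dtype=torch.float32)

hidden_states = torch.randn(4, hidden_size, dtype=torch.bfloat16)
visual_token_mask = torch.tensor([True, False, True, False])

# Split tokens by modality, then upcast each gate input to float32,
# mirroring the .to(dtype=torch.float32) casts in the diff.
text_logits = text_gate(
    hidden_states[~visual_token_mask].to(dtype=torch.float32))
vision_logits = vision_gate(
    hidden_states[visual_token_mask].to(dtype=torch.float32))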