diff --git a/vllm/model_executor/models/ernie45_moe.py b/vllm/model_executor/models/ernie45_moe.py index d262e9e9da50..38c5249380c3 100644 --- a/vllm/model_executor/models/ernie45_moe.py +++ b/vllm/model_executor/models/ernie45_moe.py @@ -120,11 +120,12 @@ class Ernie4_5_MoeMoE(nn.Module): self.gate = ReplicatedLinear(config.hidden_size, config.moe_num_experts, bias=False, + params_dtype=torch.float32, quant_config=None, prefix=f"{prefix}.gate") self.gate.e_score_correction_bias = nn.Parameter( - torch.empty(config.moe_num_experts)) + torch.empty(config.moe_num_experts, dtype=torch.float32)) self.experts = FusedMoE( num_experts=config.moe_num_experts, @@ -157,7 +158,7 @@ class Ernie4_5_MoeMoE(nn.Module): if self.has_shared_experts: shared_output = self.shared_experts(hidden_states) - router_logits, _ = self.gate(hidden_states) + router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32)) final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits) diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py index f55016f7ccb3..21772f766b40 100644 --- a/vllm/model_executor/models/ernie45_vl_moe.py +++ b/vllm/model_executor/models/ernie45_vl_moe.py @@ -199,7 +199,7 @@ class Ernie4_5_VLMoeMoE(nn.Module): assert config.moe_num_experts[0] == config.moe_num_experts[1] self.e_score_correction_bias = nn.Parameter( - torch.empty(2, config.moe_num_experts[0])) + torch.empty(2, config.moe_num_experts[0], dtype=torch.float32)) assert text_moe_layer_start_index <= text_moe_layer_end_index @@ -209,6 +209,7 @@ class Ernie4_5_VLMoeMoE(nn.Module): config.hidden_size, config.moe_num_experts[0], bias=False, + params_dtype=torch.float32, quant_config=quant_config, prefix=f"{prefix}.text_experts_gate") @@ -238,6 +239,7 @@ class Ernie4_5_VLMoeMoE(nn.Module): config.hidden_size, config.moe_num_experts[1], bias=False, + params_dtype=torch.float32, quant_config=quant_config, prefix=f"{prefix}.vision_experts_gate") @@ -288,7 +290,8 @@ class Ernie4_5_VLMoeMoE(nn.Module): if visual_token_mask is not None and visual_token_mask.all(): # only vision modal input - router_logits, _ = self.vision_experts_gate(hidden_states) + router_logits, _ = self.vision_experts_gate( + hidden_states.to(dtype=torch.float32)) final_hidden_states = self.vision_experts( hidden_states=hidden_states, router_logits=router_logits) elif visual_token_mask is not None and visual_token_mask.any(): @@ -303,19 +306,21 @@ class Ernie4_5_VLMoeMoE(nn.Module): vision_hidden_states = hidden_states[visual_token_mask].reshape( -1, self.hidden_size) - text_router_logits, _ = self.text_experts_gate(text_hidden_states) + text_router_logits, _ = self.text_experts_gate( + text_hidden_states.to(dtype=torch.float32)) final_hidden_states[text_token_mask] = self.text_experts( hidden_states=text_hidden_states, router_logits=text_router_logits).flatten() vision_router_logits, _ = self.vision_experts_gate( - vision_hidden_states) + vision_hidden_states.to(dtype=torch.float32)) final_hidden_states[visual_token_mask] = self.vision_experts( hidden_states=vision_hidden_states, router_logits=vision_router_logits).flatten() else: # only text modal input - text_router_logits, _ = self.text_experts_gate(hidden_states) + text_router_logits, _ = self.text_experts_gate( + hidden_states.to(dtype=torch.float32)) final_hidden_states = self.text_experts( hidden_states=hidden_states, router_logits=text_router_logits)