[Bugfix] fix deepseek fp16 scale bug (#14809)

Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Jinzhen Lin authored on 2025-04-09 04:56:09 +08:00, committed by GitHub
parent e1a2c699dd
commit db10422184


@@ -160,14 +160,16 @@ class DeepseekV2MoE(nn.Module):
                 hidden_states=hidden_states,
                 router_logits=router_logits) * self.routed_scaling_factor
         else:
-            # This is a special case to avoid FP16 overflow
+            # Fix FP16 overflow
+            # See DeepseekV2DecoderLayer for more details.
             final_hidden_states = self.experts(hidden_states=hidden_states,
                                                router_logits=router_logits)
         if shared_output is not None:
             if hidden_states.dtype != torch.float16:
                 final_hidden_states = final_hidden_states + shared_output
             else:
-                # This is a special case to avoid FP16 overflow
+                # Fix FP16 overflow
+                # See DeepseekV2DecoderLayer for more details.
                 final_hidden_states = final_hidden_states + shared_output \
                     * (1. / self.routed_scaling_factor)
         if self.tp_size > 1:
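The FP16 branch above avoids materializing `expert_output * routed_scaling_factor` in half precision: the routed output is left unscaled and the shared-expert term is divided by the factor instead, so both terms sit on the same (unscaled) scale until the decoder layer compensates. A minimal sketch of the overflow being dodged, with an illustrative factor of 16 and made-up activation magnitudes:

```python
import torch

routed_scaling_factor = 16.0  # illustrative value, not read from any config
expert_out = torch.full((4,), 8192.0, dtype=torch.float16)
shared_output = torch.full((4,), 4096.0, dtype=torch.float16)

# Scaling up in fp16 blows past the fp16 max of 65504 and saturates to inf.
print(expert_out * routed_scaling_factor)  # tensor([inf, inf, inf, inf], dtype=torch.float16)

# Deferred form used in the diff: keep the routed output unscaled and
# divide the shared branch, so the sum stays finite in fp16.
print(expert_out + shared_output * (1. / routed_scaling_factor))  # 8448.0, finite
```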
@@ -499,6 +501,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         # DecoderLayers are created with `make_layers` which passes the prefix
         # with the layer's index.
         layer_idx = int(prefix.split(sep='.')[-1])
+        self.layer_idx = layer_idx
         if model_config.use_mla:
             attn_cls = DeepseekV2MLAAttention
         else:
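`self.layer_idx` is stored because the forward pass below scales the shared residual only on the first layer. The index is recovered from the module prefix that `make_layers` passes in; a quick sketch with a hypothetical prefix value:

```python
# Prefixes look like "model.layers.<i>"; this particular string is a
# hypothetical example, not taken from a running model.
prefix = "model.layers.7"
layer_idx = int(prefix.split(sep='.')[-1])
assert layer_idx == 7
```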
@@ -561,19 +564,30 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states=hidden_states,
         )

-        # Fully Connected
-        if isinstance(self.mlp, DeepseekV2MoE) and \
-            hidden_states.dtype == torch.float16:
-            # This is a special case to avoid FP16 overflow
+        if hidden_states.dtype == torch.float16:
+            # Fix FP16 overflow
             # We scale both hidden_states and residual before
             # rmsnorm, and rmsnorm result would not affect by scale.
             hidden_states *= 1. / self.routed_scaling_factor
+            if self.layer_idx == 0:
+                # The residual is shared by all layers, we only scale it on
+                # first layer.
+                residual *= 1. / self.routed_scaling_factor
+
+        # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
         hidden_states = self.mlp(hidden_states)

-        if isinstance(self.mlp, DeepseekV2MLP) and \
-            hidden_states.dtype == torch.float16:
-            # This is a special case to avoid FP16 overflow
+        if isinstance(self.mlp,
+                      DeepseekV2MLP) and hidden_states.dtype == torch.float16:
+            # Fix FP16 overflow
+            # Scaling the DeepseekV2MLP output, it is the input of
+            # input_layernorm of next decoder layer.
+            # The scaling of DeepseekV2MOE output would be done in the forward
+            # of DeepseekV2MOE
             hidden_states *= 1. / self.routed_scaling_factor
-            residual *= 1. / self.routed_scaling_factor

         return hidden_states, residual
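The rewrite relies on RMSNorm being scale-invariant: dividing `hidden_states` (and, once, the shared residual) by `routed_scaling_factor` before `post_attention_layernorm` leaves the normalized values fed to the MLP unchanged up to the epsilon term, while keeping the residual stream small enough for FP16. It also drops the old per-layer `residual` rescaling in the MLP branch in favor of scaling the shared residual exactly once, on layer 0, so a single net factor threads through the whole stack. A minimal check of the invariance claim, using a reference RMSNorm rather than vLLM's kernel and 16.0 as a stand-in factor:

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Reference RMSNorm: x / rms(x) * weight.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

x = torch.randn(4, 8)
w = torch.ones(8)

# rms(x / c) == rms(x) / c, so normalizing cancels the factor; the match is
# exact only up to eps, which no longer cancels perfectly after scaling.
scaled = x * (1. / 16.0)  # 16.0 stands in for routed_scaling_factor
assert torch.allclose(rms_norm(x, w), rms_norm(scaled, w), atol=1e-2)
```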