mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-26 06:44:26 +08:00
[Bugfix] fix deepseek fp16 scale bug (#14809)
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
e1a2c699dd
commit
db10422184
@ -160,14 +160,16 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
router_logits=router_logits) * self.routed_scaling_factor
|
router_logits=router_logits) * self.routed_scaling_factor
|
||||||
else:
|
else:
|
||||||
# This is a special case to avoid FP16 overflow
|
# Fix FP16 overflow
|
||||||
|
# See DeepseekV2DecoderLayer for more details.
|
||||||
final_hidden_states = self.experts(hidden_states=hidden_states,
|
final_hidden_states = self.experts(hidden_states=hidden_states,
|
||||||
router_logits=router_logits)
|
router_logits=router_logits)
|
||||||
if shared_output is not None:
|
if shared_output is not None:
|
||||||
if hidden_states.dtype != torch.float16:
|
if hidden_states.dtype != torch.float16:
|
||||||
final_hidden_states = final_hidden_states + shared_output
|
final_hidden_states = final_hidden_states + shared_output
|
||||||
else:
|
else:
|
||||||
# This is a special case to avoid FP16 overflow
|
# Fix FP16 overflow
|
||||||
|
# See DeepseekV2DecoderLayer for more details.
|
||||||
final_hidden_states = final_hidden_states + shared_output \
|
final_hidden_states = final_hidden_states + shared_output \
|
||||||
* (1. / self.routed_scaling_factor)
|
* (1. / self.routed_scaling_factor)
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
@ -499,6 +501,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
# DecoderLayers are created with `make_layers` which passes the prefix
|
# DecoderLayers are created with `make_layers` which passes the prefix
|
||||||
# with the layer's index.
|
# with the layer's index.
|
||||||
layer_idx = int(prefix.split(sep='.')[-1])
|
layer_idx = int(prefix.split(sep='.')[-1])
|
||||||
|
self.layer_idx = layer_idx
|
||||||
if model_config.use_mla:
|
if model_config.use_mla:
|
||||||
attn_cls = DeepseekV2MLAAttention
|
attn_cls = DeepseekV2MLAAttention
|
||||||
else:
|
else:
|
||||||
@ -561,19 +564,30 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Fully Connected
|
if hidden_states.dtype == torch.float16:
|
||||||
if isinstance(self.mlp, DeepseekV2MoE) and \
|
# Fix FP16 overflow
|
||||||
hidden_states.dtype == torch.float16:
|
# We scale both hidden_states and residual before
|
||||||
# This is a special case to avoid FP16 overflow
|
# rmsnorm, and rmsnorm result would not affect by scale.
|
||||||
hidden_states *= 1. / self.routed_scaling_factor
|
hidden_states *= 1. / self.routed_scaling_factor
|
||||||
|
if self.layer_idx == 0:
|
||||||
|
# The residual is shared by all layers, we only scale it on
|
||||||
|
# first layer.
|
||||||
|
residual *= 1. / self.routed_scaling_factor
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
hidden_states, residual = self.post_attention_layernorm(
|
hidden_states, residual = self.post_attention_layernorm(
|
||||||
hidden_states, residual)
|
hidden_states, residual)
|
||||||
hidden_states = self.mlp(hidden_states)
|
hidden_states = self.mlp(hidden_states)
|
||||||
if isinstance(self.mlp, DeepseekV2MLP) and \
|
|
||||||
hidden_states.dtype == torch.float16:
|
if isinstance(self.mlp,
|
||||||
# This is a special case to avoid FP16 overflow
|
DeepseekV2MLP) and hidden_states.dtype == torch.float16:
|
||||||
|
# Fix FP16 overflow
|
||||||
|
# Scaling the DeepseekV2MLP output, it is the input of
|
||||||
|
# input_layernorm of next decoder layer.
|
||||||
|
# The scaling of DeepseekV2MOE output would be done in the forward
|
||||||
|
# of DeepseekV2MOE
|
||||||
hidden_states *= 1. / self.routed_scaling_factor
|
hidden_states *= 1. / self.routed_scaling_factor
|
||||||
residual *= 1. / self.routed_scaling_factor
|
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user