diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index fcab533ed2dc5..23b450aeddac9 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -160,14 +160,16 @@ class DeepseekV2MoE(nn.Module):
                 hidden_states=hidden_states,
                 router_logits=router_logits) * self.routed_scaling_factor
         else:
-            # This is a special case to avoid FP16 overflow
+            # Fix FP16 overflow
+            # See DeepseekV2DecoderLayer for more details.
             final_hidden_states = self.experts(hidden_states=hidden_states,
                                                router_logits=router_logits)
         if shared_output is not None:
             if hidden_states.dtype != torch.float16:
                 final_hidden_states = final_hidden_states + shared_output
             else:
-                # This is a special case to avoid FP16 overflow
+                # Fix FP16 overflow
+                # See DeepseekV2DecoderLayer for more details.
                 final_hidden_states = final_hidden_states + shared_output \
                     * (1. / self.routed_scaling_factor)
         if self.tp_size > 1:
@@ -499,6 +501,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         # DecoderLayers are created with `make_layers` which passes the prefix
         # with the layer's index.
         layer_idx = int(prefix.split(sep='.')[-1])
+        self.layer_idx = layer_idx
         if model_config.use_mla:
             attn_cls = DeepseekV2MLAAttention
         else:
@@ -561,19 +564,30 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states=hidden_states,
         )
 
-        # Fully Connected
-        if isinstance(self.mlp, DeepseekV2MoE) and \
-            hidden_states.dtype == torch.float16:
-            # This is a special case to avoid FP16 overflow
+        if hidden_states.dtype == torch.float16:
+            # Fix FP16 overflow
+            # We scale both hidden_states and residual before
+            # rmsnorm; the rmsnorm result is not affected by the scale.
             hidden_states *= 1. / self.routed_scaling_factor
+            if self.layer_idx == 0:
+                # The residual is shared by all layers, so we only scale it
+                # in the first layer.
+                residual *= 1. / self.routed_scaling_factor
+
+        # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        if isinstance(self.mlp, DeepseekV2MLP) and \
-            hidden_states.dtype == torch.float16:
-            # This is a special case to avoid FP16 overflow
+
+        if isinstance(self.mlp,
+                      DeepseekV2MLP) and hidden_states.dtype == torch.float16:
+            # Fix FP16 overflow
+            # Scale the DeepseekV2MLP output, which is the input of the
+            # next decoder layer's input_layernorm.
+            # The DeepseekV2MoE output is scaled in the forward of
+            # DeepseekV2MoE itself.
             hidden_states *= 1. / self.routed_scaling_factor
-            residual *= 1. / self.routed_scaling_factor
+
         return hidden_states, residual
 
 
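Why the pre-layernorm scaling is safe: RMSNorm is invariant to a positive rescaling of its input (up to the small epsilon term), so dividing hidden_states and residual by routed_scaling_factor before post_attention_layernorm prevents FP16 overflow without changing the normalized output. Below is a minimal standalone sketch of that property; it is not part of the diff, and the rms_norm helper plus the constants are illustrative assumptions rather than vLLM's actual RMSNorm implementation.

# Standalone sketch (assumed helper, not vLLM's RMSNorm): demonstrates that
# scaling the input by 1/routed_scaling_factor leaves the rmsnorm output
# essentially unchanged, which is what the comments in the diff rely on.
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # y = x / sqrt(mean(x^2) + eps) * weight
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

routed_scaling_factor = 16.0      # example value, not taken from a real config
weight = torch.ones(8)
hidden_states = torch.randn(2, 8)

out_unscaled = rms_norm(hidden_states, weight)
out_scaled = rms_norm(hidden_states * (1. / routed_scaling_factor), weight)

# Identical up to the eps term, so the scaling only tames FP16 magnitudes
# and does not alter the layer's result.
print(torch.allclose(out_unscaled, out_scaled, atol=1e-3))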