diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index d8a081af125c5..a8eb4a69b6f2b 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -1135,6 +1135,8 @@ class DeepseekV2DecoderLayer(nn.Module):
             dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
         )
 
+        self.use_mha = use_mha
+
         if use_mha:
             attn_cls = DeepseekAttention
         elif model_config.use_mla:
@@ -1196,11 +1198,14 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states = self.input_layernorm(hidden_states)
         else:
             hidden_states, residual = self.input_layernorm(hidden_states, residual)
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            llama_4_scaling=llama_4_scaling,
-        )
+
+        attn_kwargs = {
+            "positions": positions,
+            "hidden_states": hidden_states,
+        }
+        if not self.use_mha:
+            attn_kwargs["llama_4_scaling"] = llama_4_scaling
+        hidden_states = self.self_attn(**attn_kwargs)
         if (
             not isinstance(self.self_attn, DeepseekAttention)
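For context, here is a minimal standalone sketch of the conditional-kwargs dispatch pattern this hunk introduces: the MHA attention path does not accept `llama_4_scaling`, so the kwarg is forwarded only when the layer is not using MHA. The class and method names below are illustrative placeholders, not the actual vLLM classes.

```python
# Sketch only: demonstrates forwarding an optional kwarg based on self.use_mha,
# mirroring the diff above. Not vLLM code; classes are stand-ins.

class MhaAttention:
    # MHA-style attention: does not take llama_4_scaling.
    def __call__(self, positions, hidden_states):
        return hidden_states  # placeholder


class MlaAttention:
    # Non-MHA attention: accepts the extra scaling argument.
    def __call__(self, positions, hidden_states, llama_4_scaling=None):
        return hidden_states  # placeholder


class DecoderLayer:
    def __init__(self, use_mha: bool):
        self.use_mha = use_mha
        self.self_attn = MhaAttention() if use_mha else MlaAttention()

    def forward(self, positions, hidden_states, llama_4_scaling=None):
        attn_kwargs = {
            "positions": positions,
            "hidden_states": hidden_states,
        }
        if not self.use_mha:
            # Only the non-MHA attention signature takes this argument,
            # so add it conditionally instead of passing it unconditionally.
            attn_kwargs["llama_4_scaling"] = llama_4_scaling
        return self.self_attn(**attn_kwargs)
```

Building the kwargs dict once and adding the optional key conditionally keeps a single call site for `self_attn`, rather than duplicating the call in both branches.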