[BUGFIX] llama_4_scaling wrongly passed to DeepseekAttention (#29908)

Signed-off-by: juliendenize <julien.denize@mistral.ai>
Author: Julien Denize, 2025-12-02 23:51:20 +01:00 (committed by GitHub)
Parent: 0a9caca9f5
Commit: 5e5646e206


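For context, a minimal sketch of the kind of failure this patch guards against. The class below is a made-up stand-in, not the real DeepseekAttention, and the exact breakage in vLLM may differ; the point is only that nn.Module forwards keyword arguments straight to forward(), so an unexpected `llama_4_scaling` keyword surfaces as a TypeError at the call site.

```python
# Hedged sketch of the failure mode with a toy module (assumed signature,
# not the actual DeepseekAttention implementation).
import torch
from torch import nn


class ToyMHAAttention(nn.Module):
    """Toy stand-in for an MHA attention whose forward() has no llama_4_scaling."""

    def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states


if __name__ == "__main__":
    attn = ToyMHAAttention()
    x = torch.randn(4, 8)
    try:
        attn(positions=torch.arange(4), hidden_states=x, llama_4_scaling=torch.ones(4, 1))
    except TypeError as exc:
        # e.g. "forward() got an unexpected keyword argument 'llama_4_scaling'"
        print(f"TypeError: {exc}")
```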
@@ -1135,6 +1135,8 @@ class DeepseekV2DecoderLayer(nn.Module):
             dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
         )
+        self.use_mha = use_mha
         if use_mha:
             attn_cls = DeepseekAttention
         elif model_config.use_mla:
@@ -1196,11 +1198,14 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states = self.input_layernorm(hidden_states)
         else:
             hidden_states, residual = self.input_layernorm(hidden_states, residual)
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            llama_4_scaling=llama_4_scaling,
-        )
+        attn_kwargs = {
+            "positions": positions,
+            "hidden_states": hidden_states,
+        }
+        if not self.use_mha:
+            attn_kwargs["llama_4_scaling"] = llama_4_scaling
+        hidden_states = self.self_attn(**attn_kwargs)
         if (
             not isinstance(self.self_attn, DeepseekAttention)
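Below is a self-contained sketch of the conditional-kwargs pattern the patch applies. All class names are hypothetical stand-ins; only the control flow (store the flag in __init__, branch on it before calling self_attn) mirrors the patched DeepseekV2DecoderLayer.

```python
# Standalone sketch of the fix's pattern, under assumed toy signatures.
import torch
from torch import nn


class ToyMHAAttention(nn.Module):
    def forward(self, positions, hidden_states):
        return hidden_states


class ToyMLAAttention(nn.Module):
    def forward(self, positions, hidden_states, llama_4_scaling=None):
        if llama_4_scaling is not None:
            hidden_states = hidden_states * llama_4_scaling
        return hidden_states


class ToyDecoderLayer(nn.Module):
    def __init__(self, use_mha: bool) -> None:
        super().__init__()
        # Remember which attention flavour was built so forward() can
        # decide whether the extra keyword may be passed along.
        self.use_mha = use_mha
        self.self_attn = ToyMHAAttention() if use_mha else ToyMLAAttention()

    def forward(self, positions, hidden_states, llama_4_scaling=None):
        attn_kwargs = {"positions": positions, "hidden_states": hidden_states}
        if not self.use_mha:
            # Only the MLA-style attention understands llama_4_scaling.
            attn_kwargs["llama_4_scaling"] = llama_4_scaling
        return self.self_attn(**attn_kwargs)


if __name__ == "__main__":
    pos, x, scale = torch.arange(4), torch.randn(4, 8), torch.full((4, 1), 2.0)
    # Both variants now share the same call site without a TypeError.
    ToyDecoderLayer(use_mha=True)(pos, x, scale)
    ToyDecoderLayer(use_mha=False)(pos, x, scale)
```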