[BUGFIX] llama_4_scaling wrongly passed to DeepseekAttention (#29908)

Signed-off-by: juliendenize <julien.denize@mistral.ai>
This commit is contained in:
Julien Denize 2025-12-02 23:51:20 +01:00 committed by GitHub
parent 0a9caca9f5
commit 5e5646e206
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -1135,6 +1135,8 @@ class DeepseekV2DecoderLayer(nn.Module):
             dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
         )
+        self.use_mha = use_mha
         if use_mha:
             attn_cls = DeepseekAttention
         elif model_config.use_mla:
@@ -1196,11 +1198,14 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states = self.input_layernorm(hidden_states)
         else:
             hidden_states, residual = self.input_layernorm(hidden_states, residual)
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            llama_4_scaling=llama_4_scaling,
-        )
+        attn_kwargs = {
+            "positions": positions,
+            "hidden_states": hidden_states,
+        }
+        if not self.use_mha:
+            attn_kwargs["llama_4_scaling"] = llama_4_scaling
+        hidden_states = self.self_attn(**attn_kwargs)
         if (
             not isinstance(self.self_attn, DeepseekAttention)