[BUGFIX] llama_4_scaling wrongly passed to DeepseekAttention (#29908)
Signed-off-by: juliendenize <julien.denize@mistral.ai>
parent 0a9caca9f5
commit 5e5646e206
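Context for the change: when both qk_nope_head_dim and qk_rope_head_dim are 0, the decoder layer builds the plain multi-head DeepseekAttention, whose call signature does not take llama_4_scaling; the previous code forwarded that keyword unconditionally, which breaks on the MHA path (as the commit title notes). Below is a minimal, self-contained sketch of that failure mode, using hypothetical simplified classes as stand-ins for the real vLLM attention modules:

```python
# Hypothetical, heavily simplified stand-ins for the real vLLM attention
# modules, only to illustrate the failure the commit title describes.


class DeepseekAttention:
    """MHA path: its call signature has no llama_4_scaling parameter."""

    def __call__(self, positions, hidden_states):
        return hidden_states


class DeepseekMLAAttention:
    """MLA path: accepts the optional llama_4_scaling argument."""

    def __call__(self, positions, hidden_states, llama_4_scaling=None):
        return hidden_states


self_attn = DeepseekAttention()  # layer built with use_mha=True
try:
    # Pre-fix behaviour: llama_4_scaling was always forwarded to self_attn.
    self_attn(positions=[0, 1], hidden_states=[[0.0], [0.0]], llama_4_scaling=None)
except TypeError as exc:
    print(f"TypeError, as hit before this fix: {exc}")
```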
```diff
@@ -1135,6 +1135,8 @@ class DeepseekV2DecoderLayer(nn.Module):
             dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
         )
 
+        self.use_mha = use_mha
+
         if use_mha:
             attn_cls = DeepseekAttention
         elif model_config.use_mla:
@@ -1196,11 +1198,14 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states = self.input_layernorm(hidden_states)
         else:
             hidden_states, residual = self.input_layernorm(hidden_states, residual)
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            llama_4_scaling=llama_4_scaling,
-        )
+
+        attn_kwargs = {
+            "positions": positions,
+            "hidden_states": hidden_states,
+        }
+        if not self.use_mha:
+            attn_kwargs["llama_4_scaling"] = llama_4_scaling
+        hidden_states = self.self_attn(**attn_kwargs)
 
         if (
             not isinstance(self.self_attn, DeepseekAttention)
```
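The fix keeps a single call site: the keyword arguments are gathered in a dict and llama_4_scaling is attached only when the layer is not on the MHA path, so each attention class receives exactly the keywords its signature declares. A small illustrative sketch of that dispatch pattern follows; the helper name and the toy mha/mla callables are hypothetical, not vLLM's actual API.

```python
# Illustrative sketch of the post-fix call pattern (hypothetical helper name;
# the real logic lives inside DeepseekV2DecoderLayer.forward).


def call_self_attn(self_attn, use_mha, positions, hidden_states, llama_4_scaling):
    attn_kwargs = {
        "positions": positions,
        "hidden_states": hidden_states,
    }
    if not use_mha:
        # Only the non-MHA (MLA) attention signature takes llama_4_scaling.
        attn_kwargs["llama_4_scaling"] = llama_4_scaling
    return self_attn(**attn_kwargs)


# Toy callables standing in for the two attention paths.
def mha(positions, hidden_states):
    return hidden_states


def mla(positions, hidden_states, llama_4_scaling=None):
    return hidden_states


# Both paths now receive only the keywords their signatures accept.
call_self_attn(mha, use_mha=True, positions=[0], hidden_states=[[0.0]], llama_4_scaling=None)
call_self_attn(mla, use_mha=False, positions=[0], hidden_states=[[0.0]], llama_4_scaling=None)
```

An alternative would have been to accept and ignore the extra keyword in DeepseekAttention, but filtering at the call site keeps the MHA signature clean and makes the MLA-only argument explicit.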