[Bugfix] Fix KDA output (#27905)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Author: Jee Jee Li <pandaleefree@gmail.com> (committed by GitHub)
Date:   2025-11-01 11:54:36 +08:00
Parent: bc4486d609
Commit: 3a5de7d2d6


@@ -259,7 +259,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
         hidden_states: torch.Tensor,
         positions: torch.Tensor,
         output: torch.Tensor,
-    ) -> torch.Tensor:
+    ) -> None:
         num_tokens = hidden_states.size(0)
         q = self.q_proj(hidden_states)[0]
         k = self.k_proj(hidden_states)[0]
@@ -291,8 +291,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
         )
         core_attn_out = self.o_norm(core_attn_out, g2)
         core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
-        return self.o_proj(core_attn_out)[0]
+        output[:] = self.o_proj(core_attn_out)[0]
 
     def _forward(
         self,
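
For context: per the new signature, the layer receives a caller-preallocated `output` tensor and is expected to fill it in place, returning `None`. The old `return self.o_proj(core_attn_out)[0]` computed the projection but never wrote it into `output`, so a caller that reads the buffer afterwards saw stale data. A minimal sketch of that in-place contract follows; the `InPlaceLayer` name and shapes are illustrative stand-ins, not vLLM's actual API.

import torch
from torch import nn

class InPlaceLayer(nn.Module):
    """Toy layer following the write-into-`output` convention above."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim, bias=False)

    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor) -> None:
        # `output[:] = ...` copies the result into the caller's buffer.
        # Rebinding the local name (`output = ...`) or returning a fresh
        # tensor would leave the buffer untouched, which is the bug the
        # KDA patch above fixes.
        output[:] = self.proj(hidden_states)

x = torch.randn(4, 8)
out = torch.zeros(4, 8)   # caller-owned buffer, read after the call
InPlaceLayer(8)(x, out)
print(out.abs().sum())    # nonzero: the layer wrote through `out`

Returning `None` makes the data flow explicit: results can only leave the layer through the shared buffer the caller handed in.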