[Bugfix] Fix KDA output (#27905)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2025-11-01 11:54:36 +08:00 committed by GitHub
parent bc4486d609
commit 3a5de7d2d6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -259,7 +259,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
) -> torch.Tensor: ) -> None:
num_tokens = hidden_states.size(0) num_tokens = hidden_states.size(0)
q = self.q_proj(hidden_states)[0] q = self.q_proj(hidden_states)[0]
k = self.k_proj(hidden_states)[0] k = self.k_proj(hidden_states)[0]
@ -291,8 +291,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
) )
core_attn_out = self.o_norm(core_attn_out, g2) core_attn_out = self.o_norm(core_attn_out, g2)
core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)") core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
output[:] = self.o_proj(core_attn_out)[0]
return self.o_proj(core_attn_out)[0]
def _forward( def _forward(
self, self,