mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-28 04:37:04 +08:00
[Bugfix] Fix KDA output (#27905)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
bc4486d609
commit
3a5de7d2d6
@ -259,7 +259,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
output: torch.Tensor,
|
output: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> None:
|
||||||
num_tokens = hidden_states.size(0)
|
num_tokens = hidden_states.size(0)
|
||||||
q = self.q_proj(hidden_states)[0]
|
q = self.q_proj(hidden_states)[0]
|
||||||
k = self.k_proj(hidden_states)[0]
|
k = self.k_proj(hidden_states)[0]
|
||||||
@ -291,8 +291,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
|
|||||||
)
|
)
|
||||||
core_attn_out = self.o_norm(core_attn_out, g2)
|
core_attn_out = self.o_norm(core_attn_out, g2)
|
||||||
core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
|
core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
|
||||||
|
output[:] = self.o_proj(core_attn_out)[0]
|
||||||
return self.o_proj(core_attn_out)[0]
|
|
||||||
|
|
||||||
def _forward(
|
def _forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user