diff --git a/vllm/model_executor/layers/kda.py b/vllm/model_executor/layers/kda.py
index 308bc8be1dece..26458f2e3c4da 100644
--- a/vllm/model_executor/layers/kda.py
+++ b/vllm/model_executor/layers/kda.py
@@ -259,7 +259,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
         hidden_states: torch.Tensor,
         positions: torch.Tensor,
         output: torch.Tensor,
-    ) -> torch.Tensor:
+    ) -> None:
         num_tokens = hidden_states.size(0)
         q = self.q_proj(hidden_states)[0]
         k = self.k_proj(hidden_states)[0]
@@ -291,8 +291,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
         )
         core_attn_out = self.o_norm(core_attn_out, g2)
         core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
-
-        return self.o_proj(core_attn_out)[0]
+        output[:] = self.o_proj(core_attn_out)[0]
 
     def _forward(
         self,
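
The change switches `KimiDeltaAttention.forward` to the in-place output convention: rather than returning a freshly allocated tensor, it writes its result into the caller-provided `output` buffer and returns `None`, so the caller can pre-allocate one tensor and reuse it across calls. A minimal sketch of the same convention follows; the `ToyLayer` module and its shapes are illustrative assumptions, not vLLM's actual API.

```python
import torch
import torch.nn as nn


class ToyLayer(nn.Module):
    """Hypothetical layer mimicking the in-place output convention."""

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor) -> None:
        # Write the projection into the caller-provided buffer instead of
        # allocating and returning a new tensor.
        output[:] = self.o_proj(hidden_states)


num_tokens, hidden_size = 4, 8
layer = ToyLayer(hidden_size)
hidden_states = torch.randn(num_tokens, hidden_size)
output = torch.empty(num_tokens, hidden_size)  # pre-allocated by the caller
layer(hidden_states, output)                   # fills `output`; returns None
```

Note that `output[:] = ...` copies the result into the existing buffer rather than rebinding the name; that in-place write is what makes the caller's pre-allocation worthwhile.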