From fa183e92713456dec682088a362dd9908100cc03 Mon Sep 17 00:00:00 2001
From: Jiangyun Zhu
Date: Thu, 13 Nov 2025 15:59:58 +0800
Subject: [PATCH] [Bugfix] fix kimi-linear crash (#28445)

Signed-off-by: zjy0516
---
 vllm/model_executor/layers/kda.py | 21 ++++++++++++---------
 1 file changed, 12 insertions(+), 9 deletions(-)

diff --git a/vllm/model_executor/layers/kda.py b/vllm/model_executor/layers/kda.py
index 26458f2e3c4d..2e7500bac718 100644
--- a/vllm/model_executor/layers/kda.py
+++ b/vllm/model_executor/layers/kda.py
@@ -44,7 +44,6 @@ def kda_attention(
     k_proj_states: torch.Tensor,
     v_proj_states: torch.Tensor,
     g1: torch.Tensor,
-    g2: torch.Tensor,
     beta: torch.Tensor,
     core_attn_out: torch.Tensor,
     layer_name: str,
@@ -56,7 +55,6 @@ def kda_attention(
         k_proj_states=k_proj_states,
         v_proj_states=v_proj_states,
         g1=g1,
-        g2=g2,
         beta=beta,
         core_attn_out=core_attn_out,
     )
@@ -67,7 +65,6 @@ def kda_attention_fake(
     k_proj_states: torch.Tensor,
     v_proj_states: torch.Tensor,
     g1: torch.Tensor,
-    g2: torch.Tensor,
     beta: torch.Tensor,
     core_attn_out: torch.Tensor,
     layer_name: str,
@@ -284,7 +281,6 @@ class KimiDeltaAttention(nn.Module, MambaBase):
             k,
             v,
             g1,
-            g2,
             beta,
             core_attn_out,
             self.prefix,
@@ -299,7 +295,6 @@ class KimiDeltaAttention(nn.Module, MambaBase):
         k_proj_states: torch.Tensor,
         v_proj_states: torch.Tensor,
         g1: torch.Tensor,
-        g2: torch.Tensor,
         beta: torch.Tensor,
         core_attn_out: torch.Tensor,
     ) -> None:
@@ -316,8 +311,15 @@ class KimiDeltaAttention(nn.Module, MambaBase):
         has_initial_state = attn_metadata.has_initial_state
         non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
         non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
+        num_actual_tokens = attn_metadata.num_actual_tokens
 
         constant_caches = self.kv_cache[forward_context.virtual_engine]
+        q_proj_states = q_proj_states[:num_actual_tokens]
+        k_proj_states = k_proj_states[:num_actual_tokens]
+        v_proj_states = v_proj_states[:num_actual_tokens]
+        g1 = g1[:num_actual_tokens]
+        beta = beta[:num_actual_tokens]
+
         (conv_state_q, conv_state_k, conv_state_v, recurrent_state) = constant_caches
         # deal with strides
         conv_state_q = conv_state_q.transpose(-1, -2)
@@ -372,7 +374,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
             ).transpose(0, 1)
         else:
             decode_conv_indices = non_spec_state_indices_tensor[
-                : attn_metadata.num_decodes
+                : attn_metadata.num_actual_tokens
             ]
             q = causal_conv1d_update(
                 q_proj_states,
@@ -438,8 +440,9 @@ class KimiDeltaAttention(nn.Module, MambaBase):
             beta=beta,
             initial_state=recurrent_state,
             use_qk_l2norm_in_kernel=True,
-            cu_seqlens=non_spec_query_start_loc,
+            cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1],
             ssm_state_indices=non_spec_state_indices_tensor,
         )
-        assert core_attn_out_non_spec.shape == core_attn_out.shape
-        core_attn_out[:] = core_attn_out_non_spec
+        core_attn_out[0, :num_actual_tokens] = core_attn_out_non_spec[
+            0, :num_actual_tokens
+        ]
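
Context for the slicing change above: the patch trims every per-token input to attn_metadata.num_actual_tokens and writes back only that prefix of the output buffer, presumably because the token dimension can be padded beyond the number of real tokens (e.g. for CUDA graph capture), so a full-shape equality assert no longer holds. Below is a minimal standalone sketch of that pattern, not vLLM code; the function name, shapes, and the stand-in kernel are illustrative assumptions.

import torch

def attend_actual_tokens(
    q: torch.Tensor,         # [num_padded_tokens, head_dim]; tail rows may be padding
    out: torch.Tensor,       # [1, num_padded_tokens, head_dim]; preallocated output buffer
    num_actual_tokens: int,  # number of real (non-padding) tokens in the batch
) -> None:
    # Slice the padding away before the kernel sees the tensor,
    # mirroring q_proj_states[:num_actual_tokens] etc. in the patch.
    q_real = q[:num_actual_tokens]

    # Stand-in for the real attention kernel (identity here).
    kernel_out = q_real.unsqueeze(0)

    # Write back only the real rows; the padded tail of `out` stays untouched,
    # which is why the patch replaces the shape assert with a prefix copy.
    out[0, :num_actual_tokens] = kernel_out[0, :num_actual_tokens]

# Example: 6 padded slots, 4 real tokens.
q = torch.randn(6, 8)
out = torch.zeros(1, 6, 8)
attend_actual_tokens(q, out, num_actual_tokens=4)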