[Bugfix] fix kimi-linear crash (#28445)

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Author: Jiangyun Zhu
Date: 2025-11-13 15:59:58 +08:00
Committed by: GitHub
Parent: 4ab34f6ef1
Commit: fa183e9271


@@ -44,7 +44,6 @@ def kda_attention(
     k_proj_states: torch.Tensor,
     v_proj_states: torch.Tensor,
     g1: torch.Tensor,
-    g2: torch.Tensor,
     beta: torch.Tensor,
     core_attn_out: torch.Tensor,
     layer_name: str,
@@ -56,7 +55,6 @@ def kda_attention(
         k_proj_states=k_proj_states,
         v_proj_states=v_proj_states,
         g1=g1,
-        g2=g2,
         beta=beta,
         core_attn_out=core_attn_out,
     )
@@ -67,7 +65,6 @@ def kda_attention_fake(
     k_proj_states: torch.Tensor,
     v_proj_states: torch.Tensor,
     g1: torch.Tensor,
-    g2: torch.Tensor,
     beta: torch.Tensor,
     core_attn_out: torch.Tensor,
     layer_name: str,
@@ -284,7 +281,6 @@ class KimiDeltaAttention(nn.Module, MambaBase):
             k,
             v,
             g1,
-            g2,
             beta,
             core_attn_out,
             self.prefix,
@@ -299,7 +295,6 @@ class KimiDeltaAttention(nn.Module, MambaBase):
         k_proj_states: torch.Tensor,
         v_proj_states: torch.Tensor,
         g1: torch.Tensor,
-        g2: torch.Tensor,
         beta: torch.Tensor,
         core_attn_out: torch.Tensor,
     ) -> None:
@@ -316,8 +311,15 @@ class KimiDeltaAttention(nn.Module, MambaBase):
         has_initial_state = attn_metadata.has_initial_state
         non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
         non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
+        num_actual_tokens = attn_metadata.num_actual_tokens
         constant_caches = self.kv_cache[forward_context.virtual_engine]
+        q_proj_states = q_proj_states[:num_actual_tokens]
+        k_proj_states = k_proj_states[:num_actual_tokens]
+        v_proj_states = v_proj_states[:num_actual_tokens]
+        g1 = g1[:num_actual_tokens]
+        beta = beta[:num_actual_tokens]
         (conv_state_q, conv_state_k, conv_state_v, recurrent_state) = constant_caches
         # deal with strides
         conv_state_q = conv_state_q.transpose(-1, -2)
@@ -372,7 +374,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
             ).transpose(0, 1)
         else:
             decode_conv_indices = non_spec_state_indices_tensor[
-                : attn_metadata.num_decodes
+                : attn_metadata.num_actual_tokens
             ]
             q = causal_conv1d_update(
                 q_proj_states,
@@ -438,8 +440,9 @@ class KimiDeltaAttention(nn.Module, MambaBase):
             beta=beta,
             initial_state=recurrent_state,
             use_qk_l2norm_in_kernel=True,
-            cu_seqlens=non_spec_query_start_loc,
+            cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1],
             ssm_state_indices=non_spec_state_indices_tensor,
         )
-        assert core_attn_out_non_spec.shape == core_attn_out.shape
-        core_attn_out[:] = core_attn_out_non_spec
+        core_attn_out[0, :num_actual_tokens] = core_attn_out_non_spec[
+            0, :num_actual_tokens
+        ]
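
Note on the pattern behind the fix: the tensors handed to this layer can be padded past the real token count (for example when the runner pads a batch), so the diff slices the projected states to attn_metadata.num_actual_tokens before the kernels run and writes back only that many rows of core_attn_out. The following is a minimal, self-contained sketch of that slice-and-partial-write-back pattern, not vLLM code; fake_kernel and forward_padded are hypothetical stand-ins.

import torch


def fake_kernel(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real attention kernel: one output row per input token.
    return x * 2.0


def forward_padded(
    q_proj_states: torch.Tensor,
    num_actual_tokens: int,
    core_attn_out: torch.Tensor,
) -> None:
    # The caller may pass tensors padded to a fixed size, so only the first
    # num_actual_tokens rows hold real data. Slice before calling the kernel...
    q_proj_states = q_proj_states[:num_actual_tokens]
    out = fake_kernel(q_proj_states)
    # ...and write back only the valid rows, leaving the padded tail untouched,
    # so the shapes match regardless of how much padding was added.
    core_attn_out[:num_actual_tokens] = out


padded, actual, dim = 8, 5, 4
q = torch.randn(padded, dim)
result = torch.zeros(padded, dim)
forward_padded(q, actual, result)
assert torch.equal(result[:actual], q[:actual] * 2.0)
assert torch.all(result[actual:] == 0)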