[Bugfix] fix kimi-linear crash (#28445)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
parent 4ab34f6ef1
commit fa183e9271
@@ -44,7 +44,6 @@ def kda_attention(
     k_proj_states: torch.Tensor,
     v_proj_states: torch.Tensor,
     g1: torch.Tensor,
-    g2: torch.Tensor,
     beta: torch.Tensor,
     core_attn_out: torch.Tensor,
     layer_name: str,
@@ -56,7 +55,6 @@ def kda_attention(
         k_proj_states=k_proj_states,
         v_proj_states=v_proj_states,
         g1=g1,
-        g2=g2,
         beta=beta,
         core_attn_out=core_attn_out,
     )
@@ -67,7 +65,6 @@ def kda_attention_fake(
     k_proj_states: torch.Tensor,
     v_proj_states: torch.Tensor,
     g1: torch.Tensor,
-    g2: torch.Tensor,
     beta: torch.Tensor,
     core_attn_out: torch.Tensor,
     layer_name: str,
@@ -284,7 +281,6 @@ class KimiDeltaAttention(nn.Module, MambaBase):
             k,
             v,
             g1,
-            g2,
             beta,
             core_attn_out,
             self.prefix,
@@ -299,7 +295,6 @@ class KimiDeltaAttention(nn.Module, MambaBase):
         k_proj_states: torch.Tensor,
         v_proj_states: torch.Tensor,
         g1: torch.Tensor,
-        g2: torch.Tensor,
         beta: torch.Tensor,
         core_attn_out: torch.Tensor,
     ) -> None:
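The hunks above drop the g2 argument in lockstep from the kda_attention op, its fake implementation kda_attention_fake, the positional call inside KimiDeltaAttention, and the _forward signature. As a generic illustration of why a custom op, its fake (meta) implementation, and every caller have to agree on one signature, here is a minimal sketch using plain torch.library decorators rather than vLLM's own registration path; the op name demo::my_attention and all shapes are made up for this example:

import torch
from torch.library import custom_op

# Hypothetical op, not the one in this diff: the point is only that the real
# implementation, the fake implementation, and all callers share one signature.
@custom_op("demo::my_attention", mutates_args=("out",))
def my_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor) -> None:
    # Stand-in for the real kernel: writes into the pre-allocated output buffer.
    out.copy_(q + k + v)

@my_attention.register_fake
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor) -> None:
    # The fake impl is used during tracing/compilation; it must accept exactly
    # the same arguments as the real op, or compiled graphs will break.
    return None

q = k = v = torch.ones(2, 4)
out = torch.empty(2, 4)
torch.ops.demo.my_attention(q, k, v, out)

Removing a parameter from only one of these three places (real op, fake, call sites) is enough to make dispatch or tracing fail, which is why the diff touches all of them together.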
@@ -316,8 +311,15 @@ class KimiDeltaAttention(nn.Module, MambaBase):
         has_initial_state = attn_metadata.has_initial_state
         non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
         non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
+        num_actual_tokens = attn_metadata.num_actual_tokens
         constant_caches = self.kv_cache[forward_context.virtual_engine]
 
+        q_proj_states = q_proj_states[:num_actual_tokens]
+        k_proj_states = k_proj_states[:num_actual_tokens]
+        v_proj_states = v_proj_states[:num_actual_tokens]
+        g1 = g1[:num_actual_tokens]
+        beta = beta[:num_actual_tokens]
+
         (conv_state_q, conv_state_k, conv_state_v, recurrent_state) = constant_caches
         # deal with strides
         conv_state_q = conv_state_q.transpose(-1, -2)
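The num_actual_tokens slicing added above trims the projected states and gates to the tokens actually present in this step before any kernel sees them. A minimal sketch of the pattern, under the assumption (suggested by the diff, not stated in the commit message) that the incoming buffers can be padded past the real token count, e.g. for CUDA-graph capture; the names and sizes below are illustrative only:

import torch

num_padded_tokens, num_actual_tokens, hidden = 16, 11, 8  # illustrative sizes

# Buffers sized for the padded batch, as a persistent-buffer runtime would hand them in.
q_proj_states = torch.randn(num_padded_tokens, hidden)
g1 = torch.randn(num_padded_tokens, hidden)
beta = torch.randn(num_padded_tokens, hidden)

# Trim to the valid prefix before any per-token kernel work, so downstream
# shapes line up with metadata that only describes the real tokens.
q_proj_states = q_proj_states[:num_actual_tokens]
g1 = g1[:num_actual_tokens]
beta = beta[:num_actual_tokens]

assert q_proj_states.shape[0] == num_actual_tokens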
@@ -372,7 +374,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
             ).transpose(0, 1)
         else:
             decode_conv_indices = non_spec_state_indices_tensor[
-                : attn_metadata.num_decodes
+                : attn_metadata.num_actual_tokens
             ]
             q = causal_conv1d_update(
                 q_proj_states,
@@ -438,8 +440,9 @@ class KimiDeltaAttention(nn.Module, MambaBase):
                 beta=beta,
                 initial_state=recurrent_state,
                 use_qk_l2norm_in_kernel=True,
-                cu_seqlens=non_spec_query_start_loc,
+                cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1],
                 ssm_state_indices=non_spec_state_indices_tensor,
             )
-        assert core_attn_out_non_spec.shape == core_attn_out.shape
-        core_attn_out[:] = core_attn_out_non_spec
+        core_attn_out[0, :num_actual_tokens] = core_attn_out_non_spec[
+            0, :num_actual_tokens
+        ]
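The final hunk swaps the full-buffer copy (and the shape assert guarding it) for a write into just the valid prefix of core_attn_out. A small self-contained sketch of why the old pattern can fail once the kernel output covers only num_actual_tokens while the destination buffer keeps its padded size; the sizes are made up for illustration:

import torch

num_padded_tokens, num_actual_tokens, head_dim = 16, 11, 4  # illustrative sizes

core_attn_out = torch.zeros(1, num_padded_tokens, head_dim)           # padded persistent buffer
core_attn_out_non_spec = torch.randn(1, num_actual_tokens, head_dim)  # kernel output for real tokens

# Old pattern: a full copy needs identical shapes, so it raises RuntimeError
# whenever the output buffer is larger than the kernel result.
try:
    core_attn_out[:] = core_attn_out_non_spec
except RuntimeError as err:
    print("full copy fails:", err)

# New pattern: write only the first num_actual_tokens positions and leave the
# padding untouched.
core_attn_out[0, :num_actual_tokens] = core_attn_out_non_spec[0, :num_actual_tokens]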