[Bugfix] fix kimi-linear crash (#28445)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
parent 4ab34f6ef1
commit fa183e9271
@@ -44,7 +44,6 @@ def kda_attention(
     k_proj_states: torch.Tensor,
     v_proj_states: torch.Tensor,
     g1: torch.Tensor,
-    g2: torch.Tensor,
     beta: torch.Tensor,
     core_attn_out: torch.Tensor,
     layer_name: str,
@@ -56,7 +55,6 @@ def kda_attention(
         k_proj_states=k_proj_states,
         v_proj_states=v_proj_states,
         g1=g1,
-        g2=g2,
         beta=beta,
         core_attn_out=core_attn_out,
     )
@@ -67,7 +65,6 @@ def kda_attention_fake(
     k_proj_states: torch.Tensor,
     v_proj_states: torch.Tensor,
     g1: torch.Tensor,
-    g2: torch.Tensor,
     beta: torch.Tensor,
     core_attn_out: torch.Tensor,
     layer_name: str,
@@ -284,7 +281,6 @@ class KimiDeltaAttention(nn.Module, MambaBase):
             k,
             v,
             g1,
-            g2,
             beta,
             core_attn_out,
             self.prefix,
@@ -299,7 +295,6 @@ class KimiDeltaAttention(nn.Module, MambaBase):
         k_proj_states: torch.Tensor,
         v_proj_states: torch.Tensor,
         g1: torch.Tensor,
-        g2: torch.Tensor,
         beta: torch.Tensor,
         core_attn_out: torch.Tensor,
     ) -> None:
@@ -316,8 +311,15 @@ class KimiDeltaAttention(nn.Module, MambaBase):
         has_initial_state = attn_metadata.has_initial_state
         non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
         non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
+        num_actual_tokens = attn_metadata.num_actual_tokens
         constant_caches = self.kv_cache[forward_context.virtual_engine]
 
+        q_proj_states = q_proj_states[:num_actual_tokens]
+        k_proj_states = k_proj_states[:num_actual_tokens]
+        v_proj_states = v_proj_states[:num_actual_tokens]
+        g1 = g1[:num_actual_tokens]
+        beta = beta[:num_actual_tokens]
+
         (conv_state_q, conv_state_k, conv_state_v, recurrent_state) = constant_caches
         # deal with strides
         conv_state_q = conv_state_q.transpose(-1, -2)
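The hunk above is the core of the fix: the per-token buffers handed to this layer can contain more rows than there are real tokens in the batch, so the projections, gate, and beta are truncated to attn_metadata.num_actual_tokens before any kernel runs on them. A minimal, self-contained sketch of that truncation pattern; the helper name and shapes below are illustrative, not vLLM internals:

import torch

# Hypothetical illustration: slice padded per-token buffers down to the rows
# that correspond to real tokens before handing them to a kernel.
def truncate_to_actual_tokens(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    num_actual_tokens: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    return q[:num_actual_tokens], k[:num_actual_tokens], v[:num_actual_tokens]

# A buffer padded to 8 token slots that only holds 5 real tokens.
padded = torch.zeros(8, 4)
q, k, v = truncate_to_actual_tokens(padded, padded.clone(), padded.clone(), 5)
assert q.shape == (5, 4)  # downstream kernels now see only the unpadded prefix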
@@ -372,7 +374,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
             ).transpose(0, 1)
         else:
             decode_conv_indices = non_spec_state_indices_tensor[
-                : attn_metadata.num_decodes
+                : attn_metadata.num_actual_tokens
             ]
             q = causal_conv1d_update(
                 q_proj_states,
@@ -438,8 +440,9 @@ class KimiDeltaAttention(nn.Module, MambaBase):
                 beta=beta,
                 initial_state=recurrent_state,
                 use_qk_l2norm_in_kernel=True,
-                cu_seqlens=non_spec_query_start_loc,
+                cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1],
                 ssm_state_indices=non_spec_state_indices_tensor,
             )
-            assert core_attn_out_non_spec.shape == core_attn_out.shape
-            core_attn_out[:] = core_attn_out_non_spec
+            core_attn_out[0, :num_actual_tokens] = core_attn_out_non_spec[
+                0, :num_actual_tokens
+            ]
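The last hunk changes the write-back in the same spirit: rather than asserting that the kernel output and the preallocated core_attn_out have identical shapes, only the first num_actual_tokens positions are copied, leaving any padded tail of the output buffer untouched. A small sketch of that write-back pattern, with hypothetical names and shapes (out_buffer and kernel_out are not vLLM identifiers):

import torch

def write_back(out_buffer: torch.Tensor, kernel_out: torch.Tensor, num_actual_tokens: int) -> None:
    # Copy only the valid prefix; rows past num_actual_tokens stay untouched,
    # so no shape-equality assertion between the two tensors is needed.
    out_buffer[0, :num_actual_tokens] = kernel_out[0, :num_actual_tokens]

out_buffer = torch.zeros(1, 8, 4)  # preallocated output padded to 8 token slots
kernel_out = torch.ones(1, 5, 4)   # result computed for the 5 real tokens only
write_back(out_buffer, kernel_out, 5)
assert out_buffer[0, 5:].sum() == 0  # padded rows are left as zeros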