[Perf] Decouple torch op from GDA to leverage torch.compile (#27871)

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Authored by Jiangyun Zhu on 2025-10-31 21:35:52 +08:00, committed by GitHub
parent 933cdea440
commit 3857eb8725

@@ -40,18 +40,36 @@ logger = init_logger(__name__)
 def kda_attention(
-    hidden_states: torch.Tensor,
-    output: torch.Tensor,
+    q_proj_states: torch.Tensor,
+    k_proj_states: torch.Tensor,
+    v_proj_states: torch.Tensor,
+    g1: torch.Tensor,
+    g2: torch.Tensor,
+    beta: torch.Tensor,
+    core_attn_out: torch.Tensor,
     layer_name: str,
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
-    self._forward(hidden_states=hidden_states, output=output)
+    self._forward(
+        q_proj_states=q_proj_states,
+        k_proj_states=k_proj_states,
+        v_proj_states=v_proj_states,
+        g1=g1,
+        g2=g2,
+        beta=beta,
+        core_attn_out=core_attn_out,
+    )
 
 
 def kda_attention_fake(
-    hidden_states: torch.Tensor,
-    output: torch.Tensor,
+    q_proj_states: torch.Tensor,
+    k_proj_states: torch.Tensor,
+    v_proj_states: torch.Tensor,
+    g1: torch.Tensor,
+    g2: torch.Tensor,
+    beta: torch.Tensor,
+    core_attn_out: torch.Tensor,
     layer_name: str,
 ) -> None:
     return
@@ -60,7 +78,7 @@ def kda_attention_fake(
 direct_register_custom_op(
     op_name="kda_attention",
     op_func=kda_attention,
-    mutates_args=["output"],
+    mutates_args=["core_attn_out"],
     fake_impl=kda_attention_fake,
 )
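With mutates_args switched to the buffer the op actually writes, the custom-op boundary now contains only the non-traceable core attention kernel; everything fed into it is plain PyTorch that torch.compile can capture. Below is a minimal, self-contained sketch of that pattern, written against plain torch.library rather than vLLM's direct_register_custom_op helper; the demo:: namespace, toy_layer, and the softmax stand-in kernel are illustrative only, not vLLM code.

import torch


@torch.library.custom_op("demo::core_attn", mutates_args=("core_attn_out",))
def core_attn(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, core_attn_out: torch.Tensor
) -> None:
    # Stand-in for the real recurrent/chunked kernel: write the result in place.
    core_attn_out.copy_(torch.softmax(q @ k.transpose(-1, -2), dim=-1) @ v)


@core_attn.register_fake
def _(q, k, v, core_attn_out) -> None:
    # The fake impl does nothing: the op allocates no new tensors, so there
    # are no output shapes to describe during tracing or the profile run.
    return


@torch.compile
def toy_layer(x, wq, wk, wv, wo):
    # Projections and the output matmul stay outside the opaque op, so the
    # compiler can trace and fuse them; only demo::core_attn is left opaque.
    q, k, v = x @ wq, x @ wk, x @ wv
    out = torch.empty_like(v)
    torch.ops.demo.core_attn(q, k, v, out)
    return out @ wo

Declaring core_attn_out in mutates_args is what keeps this correct under compilation: the graph knows the buffer is written in place, so the call is neither reordered past its consumers nor dead-code-eliminated, mirroring the mutates_args=["output"] to mutates_args=["core_attn_out"] change above.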
@@ -241,37 +259,56 @@ class KimiDeltaAttention(nn.Module, MambaBase):
         hidden_states: torch.Tensor,
         positions: torch.Tensor,
-        output: torch.Tensor,
-    ) -> None:
-        return torch.ops.vllm.kda_attention(
-            hidden_states,
-            output,
+    ) -> torch.Tensor:
+        num_tokens = hidden_states.size(0)
+        q = self.q_proj(hidden_states)[0]
+        k = self.k_proj(hidden_states)[0]
+        v = self.v_proj(hidden_states)[0]
+        beta = self.b_proj(hidden_states)[0].float().sigmoid()
+        g1 = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
+        g1 = fused_kda_gate(g1, self.A_log, self.head_dim, g_bias=self.dt_bias)
+        beta = beta.unsqueeze(0)
+        g1 = g1.unsqueeze(0)
+        g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
+        g2 = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
+        core_attn_out = torch.zeros(
+            (1, num_tokens, self.local_num_heads, self.head_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+        torch.ops.vllm.kda_attention(
+            q,
+            k,
+            v,
+            g1,
+            g2,
+            beta,
+            core_attn_out,
             self.prefix,
         )
+        core_attn_out = self.o_norm(core_attn_out, g2)
+        core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
+        return self.o_proj(core_attn_out)[0]
 
     def _forward(
         self,
-        hidden_states: torch.Tensor,
-        output: torch.Tensor,
+        q_proj_states: torch.Tensor,
+        k_proj_states: torch.Tensor,
+        v_proj_states: torch.Tensor,
+        g1: torch.Tensor,
+        g2: torch.Tensor,
+        beta: torch.Tensor,
+        core_attn_out: torch.Tensor,
     ) -> None:
         forward_context = get_forward_context()
         attn_metadata: AttentionMetadata = forward_context.attn_metadata
 
         if attn_metadata is None:
-            # V1 profile run
-            # Mimic the memory allocation in the real run
-            q = torch.empty_like(hidden_states)
-            k = torch.empty_like(hidden_states)
-            v = torch.empty_like(hidden_states)
-            g = hidden_states.new_empty(
-                hidden_states.size(0),
-                self.local_num_heads,
-                self.head_dim,
-                dtype=torch.float32,
-            )
-            beta = torch.empty(
-                hidden_states.size(0), self.local_num_heads, dtype=torch.float32
-            )
-            core_attn_out = torch.empty_like(hidden_states)
+            # # V1 profile run
             return
 
         assert isinstance(attn_metadata, dict)
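With the q/k/v projections, the fused gate g1, the output gate g2, beta, and the core_attn_out allocation all hoisted into forward(), the profile-run branch of _forward() no longer needs to mimic those allocations: every tensor the op touches is already materialized by the caller before torch.ops.vllm.kda_attention is invoked, so the branch can simply return.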
@@ -288,10 +325,6 @@ class KimiDeltaAttention(nn.Module, MambaBase):
         conv_state_k = conv_state_k.transpose(-1, -2)
         conv_state_v = conv_state_v.transpose(-1, -2)
 
-        q_proj_states = self.q_proj(hidden_states)[0]
-        k_proj_states = self.k_proj(hidden_states)[0]
-        v_proj_states = self.v_proj(hidden_states)[0]
         q_conv_weights = self.q_conv1d.weight.view(
             self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)
         )
@@ -374,14 +407,6 @@ class KimiDeltaAttention(nn.Module, MambaBase):
             lambda x: rearrange(x, "n (h d) -> 1 n h d", d=self.head_dim), (q, k, v)
         )
-        beta = self.b_proj(hidden_states)[0].float().sigmoid()
-        g = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
-        g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias)
-        beta = beta.unsqueeze(0)
-        g = g.unsqueeze(0)
 
         if attn_metadata.num_prefills > 0:
             zero_idx = non_spec_state_indices_tensor[~has_initial_state]
             recurrent_state[zero_idx] = 0
@@ -393,7 +418,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
                 q=q,
                 k=k,
                 v=v,
-                g=g,
+                g=g1,
                 beta=beta,
                 initial_state=initial_state,
                 output_final_state=True,
@@ -410,17 +435,12 @@ class KimiDeltaAttention(nn.Module, MambaBase):
                 q=q,
                 k=k,
                 v=v,
-                g=g,
+                g=g1,
                 beta=beta,
                 initial_state=recurrent_state,
                 use_qk_l2norm_in_kernel=True,
                 cu_seqlens=non_spec_query_start_loc,
                 ssm_state_indices=non_spec_state_indices_tensor,
             )
-            g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
-            g = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
-            core_attn_out = self.o_norm(core_attn_out_non_spec, g)
-            core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
-            output[:] = self.o_proj(core_attn_out)[0]
+            assert core_attn_out_non_spec.shape == core_attn_out.shape
+            core_attn_out[:] = core_attn_out_non_spec
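The tail of _forward() changes accordingly: instead of applying o_norm, the rearrange back to (n, h*d), and o_proj itself, it now just copies the kernel result into the caller-provided core_attn_out buffer (the tensor declared in mutates_args), and the rename from g to g1 distinguishes the gate passed to the kernels from the output gate g2 consumed by o_norm in forward().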