mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Perf] Decouple torch op from GDA to leverage torch.compile (#27871)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
parent 933cdea440
commit 3857eb8725
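Background (not part of the diff): torch.compile treats a registered custom op as a single opaque node, so any projections computed inside torch.ops.vllm.kda_attention were invisible to the compiler. This change moves the q/k/v and gate projections out into forward(), where they run as ordinary torch ops in the compiled region, and keeps only the KDA core behind the op boundary. The toy module below sketches that split; ToyKDA and toy_core_attention are made-up stand-ins, not vLLM APIs.

import torch
import torch.nn as nn


def toy_core_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                       out: torch.Tensor) -> None:
    # Stand-in for the opaque core op: it fills a caller-provided buffer
    # instead of returning a new tensor (the real core is a registered custom op).
    out.copy_(torch.softmax(q @ k.transpose(-1, -2), dim=-1) @ v)


class ToyKDA(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.o_proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Projections stay in the compiled region, so torch.compile can trace
        # and fuse them like any other eager ops.
        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        # Preallocate the core output and let the core fill it in place,
        # mirroring how the new forward() below preallocates core_attn_out.
        core_attn_out = torch.empty_like(v)
        toy_core_attention(q, k, v, core_attn_out)
        return self.o_proj(core_attn_out)

In the real code the core is registered through direct_register_custom_op (next hunks), so torch.compile leaves it as one opaque call while everything around it is compiled.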
@@ -40,18 +40,36 @@ logger = init_logger(__name__)
 def kda_attention(
-    hidden_states: torch.Tensor,
-    output: torch.Tensor,
+    q_proj_states: torch.Tensor,
+    k_proj_states: torch.Tensor,
+    v_proj_states: torch.Tensor,
+    g1: torch.Tensor,
+    g2: torch.Tensor,
+    beta: torch.Tensor,
+    core_attn_out: torch.Tensor,
     layer_name: str,
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
-    self._forward(hidden_states=hidden_states, output=output)
+    self._forward(
+        q_proj_states=q_proj_states,
+        k_proj_states=k_proj_states,
+        v_proj_states=v_proj_states,
+        g1=g1,
+        g2=g2,
+        beta=beta,
+        core_attn_out=core_attn_out,
+    )


 def kda_attention_fake(
-    hidden_states: torch.Tensor,
-    output: torch.Tensor,
+    q_proj_states: torch.Tensor,
+    k_proj_states: torch.Tensor,
+    v_proj_states: torch.Tensor,
+    g1: torch.Tensor,
+    g2: torch.Tensor,
+    beta: torch.Tensor,
+    core_attn_out: torch.Tensor,
     layer_name: str,
 ) -> None:
     return

@@ -60,7 +78,7 @@ def kda_attention_fake(
 direct_register_custom_op(
     op_name="kda_attention",
     op_func=kda_attention,
-    mutates_args=["output"],
+    mutates_args=["core_attn_out"],
     fake_impl=kda_attention_fake,
 )
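The mutates_args update follows from the new signature: the op no longer writes the layer's final output, it writes into the caller-provided core_attn_out buffer. As a rough sketch of what such a registration amounts to, here is a minimal equivalent written against plain torch.library (PyTorch 2.4+) rather than vLLM's direct_register_custom_op helper; the "toylib" namespace and the op body are illustrative only.

import torch


@torch.library.custom_op("toylib::kda_core", mutates_args=("core_attn_out",))
def kda_core(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    core_attn_out: torch.Tensor,
) -> None:
    # Real implementation: declared as mutating core_attn_out, so the dispatcher
    # and torch.compile know this argument is written in place.
    core_attn_out.copy_(torch.softmax(q @ k.transpose(-1, -2), dim=-1) @ v)


@kda_core.register_fake
def _(q, k, v, core_attn_out) -> None:
    # Shape-only implementation used during tracing and profile runs,
    # analogous to kda_attention_fake above: no compute, no allocation.
    return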
@@ -241,37 +259,56 @@ class KimiDeltaAttention(nn.Module, MambaBase):
         hidden_states: torch.Tensor,
         positions: torch.Tensor,
-        output: torch.Tensor,
-    ) -> None:
-        return torch.ops.vllm.kda_attention(
-            hidden_states,
-            output,
+    ) -> torch.Tensor:
+        num_tokens = hidden_states.size(0)
+        q = self.q_proj(hidden_states)[0]
+        k = self.k_proj(hidden_states)[0]
+        v = self.v_proj(hidden_states)[0]
+
+        beta = self.b_proj(hidden_states)[0].float().sigmoid()
+        g1 = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
+        g1 = fused_kda_gate(g1, self.A_log, self.head_dim, g_bias=self.dt_bias)
+        beta = beta.unsqueeze(0)
+        g1 = g1.unsqueeze(0)
+
+        g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
+        g2 = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
+
+        core_attn_out = torch.zeros(
+            (1, num_tokens, self.local_num_heads, self.head_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+        torch.ops.vllm.kda_attention(
+            q,
+            k,
+            v,
+            g1,
+            g2,
+            beta,
+            core_attn_out,
             self.prefix,
         )
+        core_attn_out = self.o_norm(core_attn_out, g2)
+        core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
+
+        return self.o_proj(core_attn_out)[0]

     def _forward(
         self,
-        hidden_states: torch.Tensor,
-        output: torch.Tensor,
+        q_proj_states: torch.Tensor,
+        k_proj_states: torch.Tensor,
+        v_proj_states: torch.Tensor,
+        g1: torch.Tensor,
+        g2: torch.Tensor,
+        beta: torch.Tensor,
+        core_attn_out: torch.Tensor,
     ) -> None:
         forward_context = get_forward_context()
         attn_metadata: AttentionMetadata = forward_context.attn_metadata

         if attn_metadata is None:
-            # V1 profile run
-            # Mimic the memory allocation in the real run
-            q = torch.empty_like(hidden_states)
-            k = torch.empty_like(hidden_states)
-            v = torch.empty_like(hidden_states)
-            g = hidden_states.new_empty(
-                hidden_states.size(0),
-                self.local_num_heads,
-                self.head_dim,
-                dtype=torch.float32,
-            )
-            beta = torch.empty(
-                hidden_states.size(0), self.local_num_heads, dtype=torch.float32
-            )
-            core_attn_out = torch.empty_like(hidden_states)
+            # V1 profile run
             return

         assert isinstance(attn_metadata, dict)
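Because forward() now allocates core_attn_out itself and the custom op only fills it, the profile-run branch in _forward() no longer has to mimic allocations and can simply return. A caller-side sketch of this preallocate-and-mutate pattern, reusing the illustrative toylib::kda_core op from the earlier sketch (again, not vLLM code):

def toy_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # The buffer lives in the compiled region; the opaque op mutates it in place.
    core_attn_out = torch.zeros_like(v)
    torch.ops.toylib.kda_core(q, k, v, core_attn_out)
    return core_attn_out


compiled = torch.compile(toy_forward)
x = torch.randn(4, 8)
print(compiled(x, x, x).shape)  # torch.Size([4, 8])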
@@ -288,10 +325,6 @@ class KimiDeltaAttention(nn.Module, MambaBase):
         conv_state_k = conv_state_k.transpose(-1, -2)
         conv_state_v = conv_state_v.transpose(-1, -2)

-        q_proj_states = self.q_proj(hidden_states)[0]
-        k_proj_states = self.k_proj(hidden_states)[0]
-        v_proj_states = self.v_proj(hidden_states)[0]
-
         q_conv_weights = self.q_conv1d.weight.view(
             self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)
         )
@@ -374,14 +407,6 @@ class KimiDeltaAttention(nn.Module, MambaBase):
             lambda x: rearrange(x, "n (h d) -> 1 n h d", d=self.head_dim), (q, k, v)
         )

-        beta = self.b_proj(hidden_states)[0].float().sigmoid()
-
-        g = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
-        g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias)
-
-        beta = beta.unsqueeze(0)
-        g = g.unsqueeze(0)
-
         if attn_metadata.num_prefills > 0:
             zero_idx = non_spec_state_indices_tensor[~has_initial_state]
             recurrent_state[zero_idx] = 0
@@ -393,7 +418,7 @@ class KimiDeltaAttention(nn.Module, MambaBase):
             q=q,
             k=k,
             v=v,
-            g=g,
+            g=g1,
             beta=beta,
             initial_state=initial_state,
             output_final_state=True,
@@ -410,17 +435,12 @@ class KimiDeltaAttention(nn.Module, MambaBase):
             q=q,
             k=k,
             v=v,
-            g=g,
+            g=g1,
             beta=beta,
             initial_state=recurrent_state,
             use_qk_l2norm_in_kernel=True,
             cu_seqlens=non_spec_query_start_loc,
             ssm_state_indices=non_spec_state_indices_tensor,
         )
-
-        g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
-        g = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
-        core_attn_out = self.o_norm(core_attn_out_non_spec, g)
-        core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
-
-        output[:] = self.o_proj(core_attn_out)[0]
+        assert core_attn_out_non_spec.shape == core_attn_out.shape
+        core_attn_out[:] = core_attn_out_non_spec