[Kernel] Support DCP for Triton backend (#25132)
Signed-off-by: Wei Wei <wwei6@meta.com>
parent 52d0cb8458
commit 05c19485a5
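Decode context parallelism (DCP) shards the KV cache of a decoding request across ranks, so each rank computes attention over only its slice of the cache and produces a partial output. To merge those partials exactly, the backend must also return the per-query log-sum-exp (LSE) of its attention scores; this change threads an lse buffer through the Triton decode kernels and the Triton MLA backend. Below is a minimal sketch of the LSE-based merge this enables; it is illustrative only (not vLLM's actual DCP reduction code), and the helper name and shapes are assumptions.

import torch

def merge_partial_attention(outputs: list[torch.Tensor],
                            lses: list[torch.Tensor]) -> torch.Tensor:
    # Each rank i contributes a partial output o_i of shape [B, H, D]
    # (softmax-normalized over its own keys only) and an LSE of shape [B, H].
    # The exact full-cache result is a weighted sum of the partials with
    # weights exp(lse_i - logsumexp_over_ranks(lse)).
    lse_stack = torch.stack(lses)                    # [R, B, H]
    lse_global = torch.logsumexp(lse_stack, dim=0)   # [B, H]
    weights = torch.exp(lse_stack - lse_global)      # [R, B, H]
    o_stack = torch.stack(outputs)                   # [R, B, H, D]
    return (weights.unsqueeze(-1) * o_stack).sum(dim=0)

The rest of the diff is plumbing for exactly this: the split-KV reduce kernel now writes the LSE, every Python wrapper accepts an lse output buffer, and TritonMLAImpl advertises the capability and returns (o, lse) instead of (o, None).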
@@ -46,6 +46,8 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
     # o will have the same shape as q
     o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
+
+    lse = torch.zeros(B, H_Q, dtype=dtype, device="cuda")
 
     b_seq_len = torch.full((B, ), seq_len, device="cuda")
 
     attn_logits = torch.empty(
@@ -60,6 +62,7 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
         k_buffer,
         v_buffer,
         o,
+        lse,
         req_to_token,
         b_seq_len,
         attn_logits,
@@ -72,12 +75,14 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
     v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)
 
     o1 = torch.zeros_like(o)
+    lse1 = torch.zeros_like(lse)
 
     decode_attention_fwd(
         q,
         k_buffer,
         v_buffer,
         o1,
+        lse1,
         req_to_page,
         b_seq_len,
         attn_logits,
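The assertions that close the test are outside the hunks shown; presumably the paged call is expected to reproduce both the output and the LSE of the non-paged reference, along these lines (hypothetical continuation, not part of the diff):

import torch

# Paged and non-paged layouts should agree on both the attention output
# and the returned log-sum-exp.
torch.testing.assert_close(o1, o)
torch.testing.assert_close(lse1, lse)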
@@ -474,12 +474,14 @@ def _decode_grouped_att_m_fwd(
 def _fwd_kernel_stage2(
     Mid_O,
     o,
+    lse,
     B_Seqlen,
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,
     stride_obs,
     stride_oh,
+    stride_lse_bs,
     NUM_KV_SPLITS: tl.constexpr,
     BLOCK_DV: tl.constexpr,
     Lv: tl.constexpr,
@@ -525,12 +527,18 @@ def _fwd_kernel_stage2(
         acc / e_sum,
         mask=mask_d,
     )
+    lse_val = e_max + tl.log(e_sum)
+    tl.store(
+        lse + cur_batch * stride_lse_bs + cur_head,
+        lse_val,
+    )
 
 
 def _decode_softmax_reducev_fwd(
     logits,
     q,
     o,
+    lse,
     v_buffer,
     b_seq_len,
     num_kv_splits,
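The new lse_val line uses the standard max-shifted log-sum-exp identity. Assuming e_max is the running maximum of the attention logits and e_sum the accumulated sum of exp(logit - e_max) across the KV splits (the usual bookkeeping in this sglang-derived reduce kernel, consistent with the acc / e_sum store just above), log(sum(exp(logits))) equals e_max + log(e_sum). A quick PyTorch check of the identity:

import torch

x = torch.randn(8)
e_max = x.max()
e_sum = torch.exp(x - e_max).sum()
lse_val = e_max + torch.log(e_sum)
# Matches the numerically stable reference implementation.
assert torch.allclose(lse_val, torch.logsumexp(x, dim=0))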
@@ -555,12 +563,14 @@ def _decode_softmax_reducev_fwd(
     _fwd_kernel_stage2[grid](
         logits,
         o,
+        lse,
         b_seq_len,
         logits.stride(0),
         logits.stride(1),
         logits.stride(2),
         o.stride(0),
         o.stride(1),
+        lse.stride(0),
         NUM_KV_SPLITS=NUM_KV_SPLITS,
         BLOCK_DV=BLOCK_DV,
         Lv=Lv,
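Only lse.stride(0) is passed because lse is assumed to be a contiguous [batch, num_q_heads] tensor, so the kernel can address element (b, h) as lse + b * stride_lse_bs + h. A small illustration of that layout assumption:

import torch

# For a contiguous [B, H_Q] tensor, stride(0) == H_Q and stride(1) == 1,
# so one batch stride is all the kernel's pointer arithmetic needs.
lse = torch.zeros(4, 16)
assert lse.stride(0) == 16 and lse.stride(1) == 1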
@@ -575,6 +585,7 @@ def decode_attention_fwd_normal(
     k_buffer,
     v_buffer,
     o,
+    lse,
     req_to_token,
     b_seq_len,
     attn_logits,
@@ -595,7 +606,7 @@ def decode_attention_fwd_normal(
         page_size,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len,
+    _decode_softmax_reducev_fwd(attn_logits, q, o, lse, v_buffer, b_seq_len,
                                 num_kv_splits)
@@ -604,6 +615,7 @@ def decode_attention_fwd_grouped(
     k_buffer,
     v_buffer,
     o,
+    lse,
     req_to_token,
     b_seq_len,
     attn_logits,
@@ -624,7 +636,7 @@ def decode_attention_fwd_grouped(
         page_size,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len,
+    _decode_softmax_reducev_fwd(attn_logits, q, o, lse, v_buffer, b_seq_len,
                                 num_kv_splits)
@@ -633,6 +645,7 @@ def decode_attention_fwd(
     k_buffer,
     v_buffer,
     o,
+    lse,
     req_to_token,
     b_seq_len,
     attn_logits,
@@ -651,6 +664,7 @@ def decode_attention_fwd(
             k_buffer,
             v_buffer,
             o,
+            lse,
             req_to_token,
             b_seq_len,
             attn_logits,
@@ -666,6 +680,7 @@ def decode_attention_fwd(
             k_buffer,
             v_buffer,
             o,
+            lse,
             req_to_token,
             b_seq_len,
             attn_logits,
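These two call sites sit inside decode_attention_fwd's dispatch between the plain and the grouped kernel. The dispatch condition itself is outside the hunks shown, so the fragment below is only a schematic of the surrounding structure (an assumption based on the sglang-style layout of this file); the point is that both branches simply pass the extra lse buffer through unchanged:

# Schematic, not part of the diff: route to the grouped kernel when several
# query heads share one KV head, otherwise use the plain kernel.
kv_group_num = q.shape[1] // v_buffer.shape[2]
if kv_group_num == 1:
    decode_attention_fwd_normal(q, k_buffer, v_buffer, o, lse, req_to_token,
                                b_seq_len, attn_logits, num_kv_splits,
                                sm_scale, page_size, logit_cap)
else:
    decode_attention_fwd_grouped(q, k_buffer, v_buffer, o, lse, req_to_token,
                                 b_seq_len, attn_logits, num_kv_splits,
                                 sm_scale, page_size, logit_cap)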
@@ -685,7 +685,7 @@ class DeepseekV2DecoderLayer(nn.Module):
     ) -> torch.Tensor:
         # Self Attention
         if residual is None:
-            residual = hidden_states
+            residual = hidden_states.clone()
             hidden_states = self.input_layernorm(hidden_states)
         else:
             hidden_states, residual = self.input_layernorm(
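In the DeepseekV2 decoder layer the saved residual becomes a private copy rather than an alias of hidden_states. The specific downstream in-place write that makes this necessary under DCP is not visible in this hunk, but the aliasing hazard a .clone() guards against is easy to demonstrate:

import torch

hidden_states = torch.ones(2, 4)
residual_alias = hidden_states          # old behaviour: shares storage
residual_copy = hidden_states.clone()   # new behaviour: private copy
hidden_states.mul_(0)                   # stand-in for a later in-place write
assert residual_alias.sum() == 0        # the alias was silently clobbered
assert residual_copy.sum() == 8         # the clone kept the original values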
@@ -32,6 +32,7 @@ class TritonMLABackend(MLACommonBackend):
 
 
 class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
+    can_return_lse_for_decode: bool = True
 
     def __init__(
         self,
@@ -139,19 +140,20 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
 
         assert isinstance(q, torch.Tensor)
         B = q.shape[0]
+        q_num_heads = q.shape[1]
         o = torch.zeros(B,
-                        self.num_heads,
+                        q_num_heads,
                         self.kv_lora_rank,
                         dtype=q.dtype,
                         device=q.device)
-
+        lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)
         num_kv_splits = 4  # TODO: heuristic
 
         # TODO(lucas) Allocate ahead of time
         attn_logits = torch.empty(
             (
                 B,
-                self.num_heads,
+                q_num_heads,
                 num_kv_splits,
                 # NOTE(lucas) idk why the +1 is here but sglang has it so we
                 # just mirror that
@@ -167,9 +169,9 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
         PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
 
         # Run MQA
-        decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
+        decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, lse,
                              attn_metadata.decode.block_table,
                              attn_metadata.decode.seq_lens, attn_logits,
                              num_kv_splits, self.scale, PAGE_SIZE)
 
-        return o, None
+        return o, lse