[Kernel] Support DCP for Triton backend (#25132)

Signed-off-by: Wei Wei <wwei6@meta.com>
Wei Wei authored on 2025-09-24 18:09:34 -07:00; committed by GitHub
parent 52d0cb8458
commit 05c19485a5
4 changed files with 30 additions and 8 deletions

View File

@@ -46,6 +46,8 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
# o will have the same shape as q
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
lse = torch.zeros(B, H_Q, dtype=dtype, device="cuda")
b_seq_len = torch.full((B, ), seq_len, device="cuda")
attn_logits = torch.empty(
@@ -60,6 +62,7 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
k_buffer,
v_buffer,
o,
lse,
req_to_token,
b_seq_len,
attn_logits,
@@ -72,12 +75,14 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)
o1 = torch.zeros_like(o)
lse1 = torch.zeros_like(lse)
decode_attention_fwd(
q,
k_buffer,
v_buffer,
o1,
lse1,
req_to_page,
b_seq_len,
attn_logits,
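
Note on the hunks above: the test now allocates a per-(batch, head) log-sum-exp buffer and passes it to decode_attention_fwd directly after the output tensor o, in both the token-indexed and the paged call. A minimal sketch of the updated call; the trailing argument names follow the backend call later in this commit and are assumptions about the test, not taken verbatim from it:

    # lse holds one log-sum-exp value per (batch, query head)
    lse = torch.zeros(B, H_Q, dtype=dtype, device="cuda")
    decode_attention_fwd(
        q, k_buffer, v_buffer,
        o,              # attention output, shape (B, H_Q, D_V)
        lse,            # new output: log-sum-exp of the attention logits, shape (B, H_Q)
        req_to_token, b_seq_len, attn_logits,
        num_kv_splits, sm_scale, PAGE_SIZE,
    )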

View File

@@ -474,12 +474,14 @@ def _decode_grouped_att_m_fwd(
def _fwd_kernel_stage2(
Mid_O,
o,
lse,
B_Seqlen,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
stride_obs,
stride_oh,
stride_lse_bs,
NUM_KV_SPLITS: tl.constexpr,
BLOCK_DV: tl.constexpr,
Lv: tl.constexpr,
@@ -525,12 +527,18 @@ def _fwd_kernel_stage2(
acc / e_sum,
mask=mask_d,
)
lse_val = e_max + tl.log(e_sum)
tl.store(
lse + cur_batch * stride_lse_bs + cur_head,
lse_val,
)
def _decode_softmax_reducev_fwd(
logits,
q,
o,
lse,
v_buffer,
b_seq_len,
num_kv_splits,
@@ -555,12 +563,14 @@ def _decode_softmax_reducev_fwd(
_fwd_kernel_stage2[grid](
logits,
o,
lse,
b_seq_len,
logits.stride(0),
logits.stride(1),
logits.stride(2),
o.stride(0),
o.stride(1),
lse.stride(0),
NUM_KV_SPLITS=NUM_KV_SPLITS,
BLOCK_DV=BLOCK_DV,
Lv=Lv,
@@ -575,6 +585,7 @@ def decode_attention_fwd_normal(
k_buffer,
v_buffer,
o,
lse,
req_to_token,
b_seq_len,
attn_logits,
@@ -595,7 +606,7 @@ def decode_attention_fwd_normal(
page_size,
logit_cap,
)
_decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len,
_decode_softmax_reducev_fwd(attn_logits, q, o, lse, v_buffer, b_seq_len,
num_kv_splits)
@@ -604,6 +615,7 @@ def decode_attention_fwd_grouped(
k_buffer,
v_buffer,
o,
lse,
req_to_token,
b_seq_len,
attn_logits,
@@ -624,7 +636,7 @@ def decode_attention_fwd_grouped(
page_size,
logit_cap,
)
_decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len,
_decode_softmax_reducev_fwd(attn_logits, q, o, lse, v_buffer, b_seq_len,
num_kv_splits)
@@ -633,6 +645,7 @@ def decode_attention_fwd(
k_buffer,
v_buffer,
o,
lse,
req_to_token,
b_seq_len,
attn_logits,
@@ -651,6 +664,7 @@ def decode_attention_fwd(
k_buffer,
v_buffer,
o,
lse,
req_to_token,
b_seq_len,
attn_logits,
@@ -666,6 +680,7 @@ def decode_attention_fwd(
k_buffer,
v_buffer,
o,
lse,
req_to_token,
b_seq_len,
attn_logits,
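
The quantity stored by the new code in _fwd_kernel_stage2 is the standard log-sum-exp of the scaled attention logits for each (batch, head): the kernel already tracks a running maximum e_max and a running sum of exponentials e_sum while it merges the KV splits, so the extra work is just lse_val = e_max + tl.log(e_sum) plus one store per (batch, head). A host-side reference of the same math, purely as a sketch (tensor names here are illustrative, not from the kernel):

    import torch

    def reference_lse(scores: torch.Tensor) -> torch.Tensor:
        # scores: (B, H_Q, seq_len) scaled attention logits for one decode token
        e_max = scores.max(dim=-1).values                      # running max in the kernel
        e_sum = torch.exp(scores - e_max.unsqueeze(-1)).sum(dim=-1)
        return e_max + torch.log(e_sum)                        # == torch.logsumexp(scores, dim=-1)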

View File

@@ -685,7 +685,7 @@ class DeepseekV2DecoderLayer(nn.Module):
) -> torch.Tensor:
# Self Attention
if residual is None:
residual = hidden_states
residual = hidden_states.clone()
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
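
A plausible reason for the new .clone(): if residual merely aliases hidden_states, any later in-place write into hidden_states corrupts residual as well, which matters once the decode path may reuse or update that buffer. This rationale is an assumption; the diff itself does not state it. A toy illustration of the aliasing hazard:

    import torch

    x = torch.ones(4)
    residual = x            # residual aliases x: same underlying storage
    x.add_(1)               # in-place update also changes residual -> tensor([2., 2., 2., 2.])

    residual = x.clone()    # independent copy
    x.add_(1)               # residual keeps its previous values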

View File

@@ -32,6 +32,7 @@ class TritonMLABackend(MLACommonBackend):
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
can_return_lse_for_decode: bool = True
def __init__(
self,
@@ -139,19 +140,20 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
assert isinstance(q, torch.Tensor)
B = q.shape[0]
q_num_heads = q.shape[1]
o = torch.zeros(B,
self.num_heads,
q_num_heads,
self.kv_lora_rank,
dtype=q.dtype,
device=q.device)
lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)
num_kv_splits = 4 # TODO: heuristic
# TODO(lucas) Allocate ahead of time
attn_logits = torch.empty(
(
B,
self.num_heads,
q_num_heads,
num_kv_splits,
# NOTE(lucas) idk why the +1 is here but sglang has it so we
# just mirror that
@@ -167,9 +169,9 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
# Run MQA
decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, lse,
attn_metadata.decode.block_table,
attn_metadata.decode.seq_lens, attn_logits,
num_kv_splits, self.scale, PAGE_SIZE)
return o, None
return o, lse
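
Returning lse instead of None, together with can_return_lse_for_decode = True, is what lets the shared MLA decode path combine partial attention outputs across decode-context-parallel (DCP) ranks: two partial softmax results over disjoint KV shards can be merged exactly once their log-sum-exp values are known. A minimal sketch of that merge; the helper name and how the common path applies it are assumptions, not part of this diff:

    import torch

    def merge_partial_attn(o1, lse1, o2, lse2):
        # o1, o2:     (B, H, D) partial attention outputs over disjoint KV shards
        # lse1, lse2: (B, H) log-sum-exp of the logits each shard saw
        lse = torch.logaddexp(lse1, lse2)               # combined normalizer
        w1 = torch.exp(lse1 - lse).unsqueeze(-1)        # rescale each shard's output
        w2 = torch.exp(lse2 - lse).unsqueeze(-1)
        return o1 * w1 + o2 * w2, lse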