[Kernel] Support DCP for Triton backend (#25132)

Signed-off-by: Wei Wei <wwei6@meta.com>
This commit is contained in:
Wei Wei 2025-09-24 18:09:34 -07:00 committed by GitHub
parent 52d0cb8458
commit 05c19485a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 30 additions and 8 deletions

View File

@@ -46,6 +46,8 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
# o will have the same shape as q # o will have the same shape as q
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
lse = torch.zeros(B, H_Q, dtype=dtype, device="cuda")
b_seq_len = torch.full((B, ), seq_len, device="cuda") b_seq_len = torch.full((B, ), seq_len, device="cuda")
attn_logits = torch.empty( attn_logits = torch.empty(
@@ -60,6 +62,7 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
k_buffer, k_buffer,
v_buffer, v_buffer,
o, o,
lse,
req_to_token, req_to_token,
b_seq_len, b_seq_len,
attn_logits, attn_logits,
@@ -72,12 +75,14 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V) v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)
o1 = torch.zeros_like(o) o1 = torch.zeros_like(o)
lse1 = torch.zeros_like(lse)
decode_attention_fwd( decode_attention_fwd(
q, q,
k_buffer, k_buffer,
v_buffer, v_buffer,
o1, o1,
lse1,
req_to_page, req_to_page,
b_seq_len, b_seq_len,
attn_logits, attn_logits,

View File

@@ -474,12 +474,14 @@ def _decode_grouped_att_m_fwd(
def _fwd_kernel_stage2( def _fwd_kernel_stage2(
Mid_O, Mid_O,
o, o,
lse,
B_Seqlen, B_Seqlen,
stride_mid_ob, stride_mid_ob,
stride_mid_oh, stride_mid_oh,
stride_mid_os, stride_mid_os,
stride_obs, stride_obs,
stride_oh, stride_oh,
stride_lse_bs,
NUM_KV_SPLITS: tl.constexpr, NUM_KV_SPLITS: tl.constexpr,
BLOCK_DV: tl.constexpr, BLOCK_DV: tl.constexpr,
Lv: tl.constexpr, Lv: tl.constexpr,
@@ -525,12 +527,18 @@ def _fwd_kernel_stage2(
acc / e_sum, acc / e_sum,
mask=mask_d, mask=mask_d,
) )
lse_val = e_max + tl.log(e_sum)
tl.store(
lse + cur_batch * stride_lse_bs + cur_head,
lse_val,
)
def _decode_softmax_reducev_fwd( def _decode_softmax_reducev_fwd(
logits, logits,
q, q,
o, o,
lse,
v_buffer, v_buffer,
b_seq_len, b_seq_len,
num_kv_splits, num_kv_splits,
@@ -555,12 +563,14 @@ def _decode_softmax_reducev_fwd(
_fwd_kernel_stage2[grid]( _fwd_kernel_stage2[grid](
logits, logits,
o, o,
lse,
b_seq_len, b_seq_len,
logits.stride(0), logits.stride(0),
logits.stride(1), logits.stride(1),
logits.stride(2), logits.stride(2),
o.stride(0), o.stride(0),
o.stride(1), o.stride(1),
lse.stride(0),
NUM_KV_SPLITS=NUM_KV_SPLITS, NUM_KV_SPLITS=NUM_KV_SPLITS,
BLOCK_DV=BLOCK_DV, BLOCK_DV=BLOCK_DV,
Lv=Lv, Lv=Lv,
@@ -575,6 +585,7 @@ def decode_attention_fwd_normal(
k_buffer, k_buffer,
v_buffer, v_buffer,
o, o,
lse,
req_to_token, req_to_token,
b_seq_len, b_seq_len,
attn_logits, attn_logits,
@@ -595,7 +606,7 @@ def decode_attention_fwd_normal(
page_size, page_size,
logit_cap, logit_cap,
) )
_decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, _decode_softmax_reducev_fwd(attn_logits, q, o, lse, v_buffer, b_seq_len,
num_kv_splits) num_kv_splits)
@@ -604,6 +615,7 @@ def decode_attention_fwd_grouped(
k_buffer, k_buffer,
v_buffer, v_buffer,
o, o,
lse,
req_to_token, req_to_token,
b_seq_len, b_seq_len,
attn_logits, attn_logits,
@@ -624,7 +636,7 @@ def decode_attention_fwd_grouped(
page_size, page_size,
logit_cap, logit_cap,
) )
_decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, _decode_softmax_reducev_fwd(attn_logits, q, o, lse, v_buffer, b_seq_len,
num_kv_splits) num_kv_splits)
@@ -633,6 +645,7 @@ def decode_attention_fwd(
k_buffer, k_buffer,
v_buffer, v_buffer,
o, o,
lse,
req_to_token, req_to_token,
b_seq_len, b_seq_len,
attn_logits, attn_logits,
@@ -651,6 +664,7 @@ def decode_attention_fwd(
k_buffer, k_buffer,
v_buffer, v_buffer,
o, o,
lse,
req_to_token, req_to_token,
b_seq_len, b_seq_len,
attn_logits, attn_logits,
@@ -666,6 +680,7 @@ def decode_attention_fwd(
k_buffer, k_buffer,
v_buffer, v_buffer,
o, o,
lse,
req_to_token, req_to_token,
b_seq_len, b_seq_len,
attn_logits, attn_logits,

View File

@@ -685,7 +685,7 @@ class DeepseekV2DecoderLayer(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states.clone()
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
else: else:
hidden_states, residual = self.input_layernorm( hidden_states, residual = self.input_layernorm(

View File

@@ -32,6 +32,7 @@ class TritonMLABackend(MLACommonBackend):
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
can_return_lse_for_decode: bool = True
def __init__( def __init__(
self, self,
@@ -139,19 +140,20 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
assert isinstance(q, torch.Tensor) assert isinstance(q, torch.Tensor)
B = q.shape[0] B = q.shape[0]
q_num_heads = q.shape[1]
o = torch.zeros(B, o = torch.zeros(B,
self.num_heads, q_num_heads,
self.kv_lora_rank, self.kv_lora_rank,
dtype=q.dtype, dtype=q.dtype,
device=q.device) device=q.device)
lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)
num_kv_splits = 4 # TODO: heuristic num_kv_splits = 4 # TODO: heuristic
# TODO(lucas) Allocate ahead of time # TODO(lucas) Allocate ahead of time
attn_logits = torch.empty( attn_logits = torch.empty(
( (
B, B,
self.num_heads, q_num_heads,
num_kv_splits, num_kv_splits,
# NOTE(lucas) idk why the +1 is here but sglang has it so we # NOTE(lucas) idk why the +1 is here but sglang has it so we
# just mirror that # just mirror that
@@ -167,9 +169,9 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
PAGE_SIZE = kv_c_and_k_pe_cache.size(1) PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
# Run MQA # Run MQA
decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, lse,
attn_metadata.decode.block_table, attn_metadata.decode.block_table,
attn_metadata.decode.seq_lens, attn_logits, attn_metadata.decode.seq_lens, attn_logits,
num_kv_splits, self.scale, PAGE_SIZE) num_kv_splits, self.scale, PAGE_SIZE)
return o, None return o, lse