mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-03 00:04:39 +08:00
[Kernel] Support DCP for Triton backend (#25132)
Signed-off-by: Wei Wei <wwei6@meta.com>
This commit is contained in:
parent
52d0cb8458
commit
05c19485a5
@ -46,6 +46,8 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
|
|||||||
# o will have the same shape as q
|
# o will have the same shape as q
|
||||||
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
|
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
|
||||||
|
|
||||||
|
lse = torch.zeros(B, H_Q, dtype=dtype, device="cuda")
|
||||||
|
|
||||||
b_seq_len = torch.full((B, ), seq_len, device="cuda")
|
b_seq_len = torch.full((B, ), seq_len, device="cuda")
|
||||||
|
|
||||||
attn_logits = torch.empty(
|
attn_logits = torch.empty(
|
||||||
@ -60,6 +62,7 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
|
|||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
o,
|
o,
|
||||||
|
lse,
|
||||||
req_to_token,
|
req_to_token,
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
attn_logits,
|
attn_logits,
|
||||||
@ -72,12 +75,14 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
|
|||||||
v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)
|
v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)
|
||||||
|
|
||||||
o1 = torch.zeros_like(o)
|
o1 = torch.zeros_like(o)
|
||||||
|
lse1 = torch.zeros_like(lse)
|
||||||
|
|
||||||
decode_attention_fwd(
|
decode_attention_fwd(
|
||||||
q,
|
q,
|
||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
o1,
|
o1,
|
||||||
|
lse1,
|
||||||
req_to_page,
|
req_to_page,
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
attn_logits,
|
attn_logits,
|
||||||
|
|||||||
@ -474,12 +474,14 @@ def _decode_grouped_att_m_fwd(
|
|||||||
def _fwd_kernel_stage2(
|
def _fwd_kernel_stage2(
|
||||||
Mid_O,
|
Mid_O,
|
||||||
o,
|
o,
|
||||||
|
lse,
|
||||||
B_Seqlen,
|
B_Seqlen,
|
||||||
stride_mid_ob,
|
stride_mid_ob,
|
||||||
stride_mid_oh,
|
stride_mid_oh,
|
||||||
stride_mid_os,
|
stride_mid_os,
|
||||||
stride_obs,
|
stride_obs,
|
||||||
stride_oh,
|
stride_oh,
|
||||||
|
stride_lse_bs,
|
||||||
NUM_KV_SPLITS: tl.constexpr,
|
NUM_KV_SPLITS: tl.constexpr,
|
||||||
BLOCK_DV: tl.constexpr,
|
BLOCK_DV: tl.constexpr,
|
||||||
Lv: tl.constexpr,
|
Lv: tl.constexpr,
|
||||||
@ -525,12 +527,18 @@ def _fwd_kernel_stage2(
|
|||||||
acc / e_sum,
|
acc / e_sum,
|
||||||
mask=mask_d,
|
mask=mask_d,
|
||||||
)
|
)
|
||||||
|
lse_val = e_max + tl.log(e_sum)
|
||||||
|
tl.store(
|
||||||
|
lse + cur_batch * stride_lse_bs + cur_head,
|
||||||
|
lse_val,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _decode_softmax_reducev_fwd(
|
def _decode_softmax_reducev_fwd(
|
||||||
logits,
|
logits,
|
||||||
q,
|
q,
|
||||||
o,
|
o,
|
||||||
|
lse,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
@ -555,12 +563,14 @@ def _decode_softmax_reducev_fwd(
|
|||||||
_fwd_kernel_stage2[grid](
|
_fwd_kernel_stage2[grid](
|
||||||
logits,
|
logits,
|
||||||
o,
|
o,
|
||||||
|
lse,
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
logits.stride(0),
|
logits.stride(0),
|
||||||
logits.stride(1),
|
logits.stride(1),
|
||||||
logits.stride(2),
|
logits.stride(2),
|
||||||
o.stride(0),
|
o.stride(0),
|
||||||
o.stride(1),
|
o.stride(1),
|
||||||
|
lse.stride(0),
|
||||||
NUM_KV_SPLITS=NUM_KV_SPLITS,
|
NUM_KV_SPLITS=NUM_KV_SPLITS,
|
||||||
BLOCK_DV=BLOCK_DV,
|
BLOCK_DV=BLOCK_DV,
|
||||||
Lv=Lv,
|
Lv=Lv,
|
||||||
@ -575,6 +585,7 @@ def decode_attention_fwd_normal(
|
|||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
o,
|
o,
|
||||||
|
lse,
|
||||||
req_to_token,
|
req_to_token,
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
attn_logits,
|
attn_logits,
|
||||||
@ -595,7 +606,7 @@ def decode_attention_fwd_normal(
|
|||||||
page_size,
|
page_size,
|
||||||
logit_cap,
|
logit_cap,
|
||||||
)
|
)
|
||||||
_decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len,
|
_decode_softmax_reducev_fwd(attn_logits, q, o, lse, v_buffer, b_seq_len,
|
||||||
num_kv_splits)
|
num_kv_splits)
|
||||||
|
|
||||||
|
|
||||||
@ -604,6 +615,7 @@ def decode_attention_fwd_grouped(
|
|||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
o,
|
o,
|
||||||
|
lse,
|
||||||
req_to_token,
|
req_to_token,
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
attn_logits,
|
attn_logits,
|
||||||
@ -624,7 +636,7 @@ def decode_attention_fwd_grouped(
|
|||||||
page_size,
|
page_size,
|
||||||
logit_cap,
|
logit_cap,
|
||||||
)
|
)
|
||||||
_decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len,
|
_decode_softmax_reducev_fwd(attn_logits, q, o, lse, v_buffer, b_seq_len,
|
||||||
num_kv_splits)
|
num_kv_splits)
|
||||||
|
|
||||||
|
|
||||||
@ -633,6 +645,7 @@ def decode_attention_fwd(
|
|||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
o,
|
o,
|
||||||
|
lse,
|
||||||
req_to_token,
|
req_to_token,
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
attn_logits,
|
attn_logits,
|
||||||
@ -651,6 +664,7 @@ def decode_attention_fwd(
|
|||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
o,
|
o,
|
||||||
|
lse,
|
||||||
req_to_token,
|
req_to_token,
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
attn_logits,
|
attn_logits,
|
||||||
@ -666,6 +680,7 @@ def decode_attention_fwd(
|
|||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
o,
|
o,
|
||||||
|
lse,
|
||||||
req_to_token,
|
req_to_token,
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
attn_logits,
|
attn_logits,
|
||||||
|
|||||||
@ -685,7 +685,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# Self Attention
|
# Self Attention
|
||||||
if residual is None:
|
if residual is None:
|
||||||
residual = hidden_states
|
residual = hidden_states.clone()
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
else:
|
else:
|
||||||
hidden_states, residual = self.input_layernorm(
|
hidden_states, residual = self.input_layernorm(
|
||||||
|
|||||||
@ -32,6 +32,7 @@ class TritonMLABackend(MLACommonBackend):
|
|||||||
|
|
||||||
|
|
||||||
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||||
|
can_return_lse_for_decode: bool = True
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -139,19 +140,20 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
|||||||
|
|
||||||
assert isinstance(q, torch.Tensor)
|
assert isinstance(q, torch.Tensor)
|
||||||
B = q.shape[0]
|
B = q.shape[0]
|
||||||
|
q_num_heads = q.shape[1]
|
||||||
o = torch.zeros(B,
|
o = torch.zeros(B,
|
||||||
self.num_heads,
|
q_num_heads,
|
||||||
self.kv_lora_rank,
|
self.kv_lora_rank,
|
||||||
dtype=q.dtype,
|
dtype=q.dtype,
|
||||||
device=q.device)
|
device=q.device)
|
||||||
|
lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)
|
||||||
num_kv_splits = 4 # TODO: heuristic
|
num_kv_splits = 4 # TODO: heuristic
|
||||||
|
|
||||||
# TODO(lucas) Allocate ahead of time
|
# TODO(lucas) Allocate ahead of time
|
||||||
attn_logits = torch.empty(
|
attn_logits = torch.empty(
|
||||||
(
|
(
|
||||||
B,
|
B,
|
||||||
self.num_heads,
|
q_num_heads,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
# NOTE(lucas) idk why the +1 is here but sglang has it so we
|
# NOTE(lucas) idk why the +1 is here but sglang has it so we
|
||||||
# just mirror that
|
# just mirror that
|
||||||
@ -167,9 +169,9 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
|||||||
PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
|
PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
|
||||||
|
|
||||||
# Run MQA
|
# Run MQA
|
||||||
decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
|
decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, lse,
|
||||||
attn_metadata.decode.block_table,
|
attn_metadata.decode.block_table,
|
||||||
attn_metadata.decode.seq_lens, attn_logits,
|
attn_metadata.decode.seq_lens, attn_logits,
|
||||||
num_kv_splits, self.scale, PAGE_SIZE)
|
num_kv_splits, self.scale, PAGE_SIZE)
|
||||||
|
|
||||||
return o, None
|
return o, lse
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user