mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-27 16:27:54 +08:00
[Deepseek-V3.2][Kernel] Integrate cuda indexer k cache gather (#26456)
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
This commit is contained in:
parent
efdef57b1f
commit
f5ed68ef63
@ -75,7 +75,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
|||||||
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import cdiv, direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
|
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
|
||||||
from vllm.v1.attention.backends.mla.indexer import (
|
from vllm.v1.attention.backends.mla.indexer import (
|
||||||
DeepseekV32IndexerBackend,
|
DeepseekV32IndexerBackend,
|
||||||
@ -483,69 +483,6 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
|
|||||||
return DeepseekV32IndexerBackend
|
return DeepseekV32IndexerBackend
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def cp_gather_indexer_k_quant_cache(
|
|
||||||
kv_cache, # [num_blocks, block_size, head_dim + 1]
|
|
||||||
dst_value, # [cu_seq_lens[-1], head_dim]
|
|
||||||
dst_scale, # [cu_seq_lens[-1], 4]
|
|
||||||
block_table, # [batch_size, num_blocks]
|
|
||||||
cu_seq_lens, # [batch_size + 1, ]
|
|
||||||
batch_size,
|
|
||||||
):
|
|
||||||
num_blocks, block_size, _ = kv_cache.shape
|
|
||||||
head_dim = dst_value.shape[-1]
|
|
||||||
kv_cache = kv_cache.view(num_blocks, -1)
|
|
||||||
|
|
||||||
expected_value = []
|
|
||||||
expected_scale = []
|
|
||||||
for b in range(batch_size):
|
|
||||||
s = cu_seq_lens[b + 1] - cu_seq_lens[b]
|
|
||||||
if s == 0:
|
|
||||||
continue
|
|
||||||
tot = cdiv(s, block_size)
|
|
||||||
blocks = block_table[b, :tot]
|
|
||||||
|
|
||||||
value = []
|
|
||||||
scale = []
|
|
||||||
full_block = torch.arange(tot - 1, device=kv_cache.device, dtype=torch.int32)
|
|
||||||
non_remaining_value = kv_cache[
|
|
||||||
blocks[full_block], : block_size * head_dim
|
|
||||||
].view(-1, head_dim)
|
|
||||||
non_remaining_scale = kv_cache[
|
|
||||||
blocks[full_block], block_size * head_dim :
|
|
||||||
].view(-1, 4)
|
|
||||||
|
|
||||||
remaining = s - (tot - 1) * block_size
|
|
||||||
|
|
||||||
value = torch.cat(
|
|
||||||
[
|
|
||||||
non_remaining_value,
|
|
||||||
kv_cache[blocks[-1], : remaining * head_dim].view(-1, head_dim),
|
|
||||||
],
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
scale = torch.cat(
|
|
||||||
[
|
|
||||||
non_remaining_scale,
|
|
||||||
kv_cache[
|
|
||||||
blocks[-1],
|
|
||||||
block_size * head_dim : block_size * head_dim + remaining * 4,
|
|
||||||
].view(-1, 4),
|
|
||||||
],
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
expected_value.append(value)
|
|
||||||
expected_scale.append(scale)
|
|
||||||
|
|
||||||
gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim)
|
|
||||||
gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4)
|
|
||||||
gather_value = gather_value.view(torch.float8_e4m3fn)
|
|
||||||
gather_scale = gather_scale.view(torch.float32)
|
|
||||||
dst_value.copy_(gather_value)
|
|
||||||
dst_scale.copy_(gather_scale)
|
|
||||||
|
|
||||||
|
|
||||||
def sparse_attn_indexer(
|
def sparse_attn_indexer(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
k_cache_prefix: str,
|
k_cache_prefix: str,
|
||||||
@ -605,19 +542,20 @@ def sparse_attn_indexer(
|
|||||||
dtype=torch.float8_e4m3fn,
|
dtype=torch.float8_e4m3fn,
|
||||||
)
|
)
|
||||||
k_scale = torch.empty(
|
k_scale = torch.empty(
|
||||||
[chunk.total_seq_lens, 1], device=k.device, dtype=torch.float32
|
[chunk.total_seq_lens, 4],
|
||||||
|
device=k.device,
|
||||||
|
dtype=torch.uint8,
|
||||||
)
|
)
|
||||||
cp_gather_indexer_k_quant_cache(
|
ops.cp_gather_indexer_k_quant_cache(
|
||||||
kv_cache,
|
kv_cache,
|
||||||
k_fp8,
|
k_fp8,
|
||||||
k_scale,
|
k_scale,
|
||||||
chunk.block_table,
|
chunk.block_table,
|
||||||
chunk.cu_seq_lens,
|
chunk.cu_seq_lens,
|
||||||
chunk.num_reqs,
|
|
||||||
)
|
)
|
||||||
logits = fp8_mqa_logits(
|
logits = fp8_mqa_logits(
|
||||||
q_fp8[chunk.token_start : chunk.token_end],
|
q_fp8[chunk.token_start : chunk.token_end],
|
||||||
(k_fp8, k_scale),
|
(k_fp8, k_scale.view(torch.float32)),
|
||||||
weights[chunk.token_start : chunk.token_end],
|
weights[chunk.token_start : chunk.token_end],
|
||||||
chunk.cu_seqlen_ks,
|
chunk.cu_seqlen_ks,
|
||||||
chunk.cu_seqlen_ke,
|
chunk.cu_seqlen_ke,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user