[Deepseek-V3.2][Kernel] Integrate cuda indexer k cache gather (#26456)

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Author: Yongye Zhu
Date: 2025-10-15 04:05:01 -04:00 (committed by GitHub)
Parent: efdef57b1f
Commit: f5ed68ef63


@@ -75,7 +75,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
-from vllm.utils import cdiv, direct_register_custom_op
+from vllm.utils import direct_register_custom_op
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerBackend,
@@ -483,69 +483,6 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
        return DeepseekV32IndexerBackend
-@torch.inference_mode()
-def cp_gather_indexer_k_quant_cache(
-    kv_cache,  # [num_blocks, block_size, head_dim + 1]
-    dst_value,  # [cu_seq_lens[-1], head_dim]
-    dst_scale,  # [cu_seq_lens[-1], 4]
-    block_table,  # [batch_size, num_blocks]
-    cu_seq_lens,  # [batch_size + 1, ]
-    batch_size,
-):
-    num_blocks, block_size, _ = kv_cache.shape
-    head_dim = dst_value.shape[-1]
-    kv_cache = kv_cache.view(num_blocks, -1)
-    expected_value = []
-    expected_scale = []
-    for b in range(batch_size):
-        s = cu_seq_lens[b + 1] - cu_seq_lens[b]
-        if s == 0:
-            continue
-        tot = cdiv(s, block_size)
-        blocks = block_table[b, :tot]
-        value = []
-        scale = []
-        full_block = torch.arange(tot - 1, device=kv_cache.device, dtype=torch.int32)
-        non_remaining_value = kv_cache[
-            blocks[full_block], : block_size * head_dim
-        ].view(-1, head_dim)
-        non_remaining_scale = kv_cache[
-            blocks[full_block], block_size * head_dim :
-        ].view(-1, 4)
-        remaining = s - (tot - 1) * block_size
-        value = torch.cat(
-            [
-                non_remaining_value,
-                kv_cache[blocks[-1], : remaining * head_dim].view(-1, head_dim),
-            ],
-            dim=0,
-        )
-        scale = torch.cat(
-            [
-                non_remaining_scale,
-                kv_cache[
-                    blocks[-1],
-                    block_size * head_dim : block_size * head_dim + remaining * 4,
-                ].view(-1, 4),
-            ],
-            dim=0,
-        )
-        expected_value.append(value)
-        expected_scale.append(scale)
-    gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim)
-    gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4)
-    gather_value = gather_value.view(torch.float8_e4m3fn)
-    gather_scale = gather_scale.view(torch.float32)
-    dst_value.copy_(gather_value)
-    dst_scale.copy_(gather_scale)
def sparse_attn_indexer(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
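The helper deleted above is, in effect, a pure-Python reference for the gather that the CUDA op takes over at the call site in the next hunk. Below is a minimal parity-check sketch, not part of this commit: it assumes a CUDA build of vLLM that exposes the op through vllm._custom_ops, the per-block layout implied by the removed reference (block_size * head_dim fp8 value bytes followed by block_size * 4 scale bytes), and illustrative sizes and index dtypes; the helper name check_cp_gather is made up here.

# Hypothetical parity check, not part of this commit: compares the new CUDA
# gather against a per-token re-implementation of the removed Python reference.
import torch

from vllm import _custom_ops as ops


def check_cp_gather(num_blocks=64, block_size=64, head_dim=128, batch_size=4):
    device = "cuda"
    # Per block: block_size * head_dim fp8 value bytes, then block_size * 4
    # scale bytes, stored as raw uint8 (layout implied by the removed reference).
    kv_cache = torch.randint(
        0, 256, (num_blocks, block_size, head_dim + 4),
        device=device, dtype=torch.uint8,
    )
    seq_lens = torch.randint(1, 4 * block_size, (batch_size,), device=device)
    cu_seq_lens = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
    cu_seq_lens[1:] = seq_lens.cumsum(0)
    max_blocks = (int(seq_lens.max()) + block_size - 1) // block_size
    block_table = (
        torch.randperm(num_blocks, device=device)[: batch_size * max_blocks]
        .view(batch_size, max_blocks)
        .to(torch.int32)
    )
    total = int(cu_seq_lens[-1])
    k_fp8 = torch.empty(total, head_dim, device=device, dtype=torch.float8_e4m3fn)
    k_scale = torch.empty(total, 4, device=device, dtype=torch.uint8)
    ops.cp_gather_indexer_k_quant_cache(
        kv_cache, k_fp8, k_scale, block_table, cu_seq_lens, batch_size
    )

    # Token-by-token reference following the removed Python implementation.
    flat = kv_cache.view(num_blocks, -1)
    ref_value, ref_scale = [], []
    for b in range(batch_size):
        for i in range(int(seq_lens[b])):
            tok = i % block_size
            row = flat[block_table[b, i // block_size]]
            ref_value.append(row[tok * head_dim : (tok + 1) * head_dim])
            base = block_size * head_dim
            ref_scale.append(row[base + tok * 4 : base + (tok + 1) * 4])
    assert torch.equal(k_fp8.view(torch.uint8), torch.stack(ref_value))
    assert torch.equal(k_scale, torch.stack(ref_scale))

If the op expects different index dtypes than assumed here, the synthetic inputs would need adjusting; the point is only that the CUDA path should reproduce the gather semantics of the removed reference.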
@@ -605,19 +542,20 @@ def sparse_attn_indexer(
                dtype=torch.float8_e4m3fn,
            )
            k_scale = torch.empty(
-                [chunk.total_seq_lens, 1], device=k.device, dtype=torch.float32
+                [chunk.total_seq_lens, 4],
+                device=k.device,
+                dtype=torch.uint8,
            )
-            cp_gather_indexer_k_quant_cache(
+            ops.cp_gather_indexer_k_quant_cache(
                kv_cache,
                k_fp8,
                k_scale,
                chunk.block_table,
                chunk.cu_seq_lens,
                chunk.num_reqs,
            )
            logits = fp8_mqa_logits(
                q_fp8[chunk.token_start : chunk.token_end],
-                (k_fp8, k_scale),
+                (k_fp8, k_scale.view(torch.float32)),
                weights[chunk.token_start : chunk.token_end],
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
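One detail in the hunk above: k_scale is now allocated as raw bytes (uint8, four bytes per token) for the CUDA op to write into, and is only reinterpreted as float32 at the fp8_mqa_logits call via .view(torch.float32). A minimal sketch of that reinterpretation, not from the diff, with an arbitrary row count:

import torch

# Four uint8 bytes per row, as the CUDA gather writes them.
k_scale = torch.empty(8, 4, dtype=torch.uint8)
scales = k_scale.view(torch.float32)  # zero-copy reinterpretation -> shape [8, 1]
assert scales.shape == (8, 1)
assert scales.data_ptr() == k_scale.data_ptr()  # same storage, no copy

This keeps the gather kernel byte-oriented while fp8_mqa_logits still receives the [N, 1] float32 scale tensor it was given before this change.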