From f5ed68ef63d0c3c084688fe00b3aeb1996ca0b6f Mon Sep 17 00:00:00 2001
From: Yongye Zhu
Date: Wed, 15 Oct 2025 04:05:01 -0400
Subject: [PATCH] [Deepseek-V3.2][Kernel] Integrate cuda indexer k cache gather (#26456)

Signed-off-by: Yongye Zhu
---
 vllm/model_executor/models/deepseek_v2.py | 74 ++---------------------
 1 file changed, 6 insertions(+), 68 deletions(-)

diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 3d26327c732ea..f33ed735f4291 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -75,7 +75,7 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
-from vllm.utils import cdiv, direct_register_custom_op
+from vllm.utils import direct_register_custom_op
 from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
 from vllm.v1.attention.backends.mla.indexer import (
     DeepseekV32IndexerBackend,
@@ -483,69 +483,6 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
         return DeepseekV32IndexerBackend
 
 
-@torch.inference_mode()
-def cp_gather_indexer_k_quant_cache(
-    kv_cache,  # [num_blocks, block_size, head_dim + 1]
-    dst_value,  # [cu_seq_lens[-1], head_dim]
-    dst_scale,  # [cu_seq_lens[-1], 4]
-    block_table,  # [batch_size, num_blocks]
-    cu_seq_lens,  # [batch_size + 1, ]
-    batch_size,
-):
-    num_blocks, block_size, _ = kv_cache.shape
-    head_dim = dst_value.shape[-1]
-    kv_cache = kv_cache.view(num_blocks, -1)
-
-    expected_value = []
-    expected_scale = []
-    for b in range(batch_size):
-        s = cu_seq_lens[b + 1] - cu_seq_lens[b]
-        if s == 0:
-            continue
-        tot = cdiv(s, block_size)
-        blocks = block_table[b, :tot]
-
-        value = []
-        scale = []
-        full_block = torch.arange(tot - 1, device=kv_cache.device, dtype=torch.int32)
-        non_remaining_value = kv_cache[
-            blocks[full_block], : block_size * head_dim
-        ].view(-1, head_dim)
-        non_remaining_scale = kv_cache[
-            blocks[full_block], block_size * head_dim :
-        ].view(-1, 4)
-
-        remaining = s - (tot - 1) * block_size
-
-        value = torch.cat(
-            [
-                non_remaining_value,
-                kv_cache[blocks[-1], : remaining * head_dim].view(-1, head_dim),
-            ],
-            dim=0,
-        )
-        scale = torch.cat(
-            [
-                non_remaining_scale,
-                kv_cache[
-                    blocks[-1],
-                    block_size * head_dim : block_size * head_dim + remaining * 4,
-                ].view(-1, 4),
-            ],
-            dim=0,
-        )
-
-        expected_value.append(value)
-        expected_scale.append(scale)
-
-    gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim)
-    gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4)
-    gather_value = gather_value.view(torch.float8_e4m3fn)
-    gather_scale = gather_scale.view(torch.float32)
-    dst_value.copy_(gather_value)
-    dst_scale.copy_(gather_scale)
-
-
 def sparse_attn_indexer(
     hidden_states: torch.Tensor,
     k_cache_prefix: str,
@@ -605,19 +542,20 @@ def sparse_attn_indexer(
                 dtype=torch.float8_e4m3fn,
             )
             k_scale = torch.empty(
-                [chunk.total_seq_lens, 1], device=k.device, dtype=torch.float32
+                [chunk.total_seq_lens, 4],
+                device=k.device,
+                dtype=torch.uint8,
             )
-            cp_gather_indexer_k_quant_cache(
+            ops.cp_gather_indexer_k_quant_cache(
                 kv_cache,
                 k_fp8,
                 k_scale,
                 chunk.block_table,
                 chunk.cu_seq_lens,
-                chunk.num_reqs,
             )
             logits = fp8_mqa_logits(
                 q_fp8[chunk.token_start : chunk.token_end],
-                (k_fp8, k_scale),
+                (k_fp8, k_scale.view(torch.float32)),
                 weights[chunk.token_start : chunk.token_end],
                 chunk.cu_seqlen_ks,
                 chunk.cu_seqlen_ke,
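
Reviewer note (not part of the patch): a minimal parity-check sketch for the new CUDA gather against the Python helper this patch deletes. It assumes the deleted helper is copied back in locally as cp_gather_indexer_k_quant_cache_ref, that the kernel is exposed as vllm._custom_ops.cp_gather_indexer_k_quant_cache (matching the `ops.` call site in sparse_attn_indexer), and that check_gather_parity and its arguments are illustrative names only. The new call writes the per-token scales as 4 raw uint8 bytes that are later reinterpreted as float32, so the comparison is done on the reinterpreted view, exactly as fp8_mqa_logits now consumes it.

    import torch

    from vllm import _custom_ops as ops


    def check_gather_parity(kv_cache, block_table, cu_seq_lens, head_dim):
        # Total gathered tokens and batch size, derived from cu_seq_lens
        # ([batch_size + 1] prefix sums), as in the deleted reference.
        total = int(cu_seq_lens[-1].item())
        batch_size = cu_seq_lens.numel() - 1

        # New path: fp8 values plus 4 raw scale bytes per token (patch layout).
        k_fp8 = torch.empty(
            [total, head_dim], device=kv_cache.device, dtype=torch.float8_e4m3fn
        )
        k_scale = torch.empty([total, 4], device=kv_cache.device, dtype=torch.uint8)
        ops.cp_gather_indexer_k_quant_cache(
            kv_cache, k_fp8, k_scale, block_table, cu_seq_lens
        )

        # Reference path: the Python gather removed by this patch, copied back
        # in as *_ref; it writes float32 scales of shape [total, 1].
        ref_value = torch.empty(
            [total, head_dim], device=kv_cache.device, dtype=torch.float8_e4m3fn
        )
        ref_scale = torch.empty([total, 1], device=kv_cache.device, dtype=torch.float32)
        cp_gather_indexer_k_quant_cache_ref(
            kv_cache, ref_value, ref_scale, block_table, cu_seq_lens, batch_size
        )

        # Compare raw fp8 bytes, and the scale bytes reinterpreted as float32.
        torch.testing.assert_close(k_fp8.view(torch.uint8), ref_value.view(torch.uint8))
        torch.testing.assert_close(k_scale.view(torch.float32), ref_scale)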