From 77e958752b67068c0417e08db6ac920ffc8238ab Mon Sep 17 00:00:00 2001
From: Chen Zhang
Date: Thu, 2 Oct 2025 10:29:12 -0700
Subject: [PATCH] [Deepseek v3.2] Support indexer prefill chunking (#25999)

Signed-off-by: Chen Zhang
Signed-off-by: yewentao256
---
 .../v1/attention/test_sparse_mla_backends.py | 22 +++
 vllm/model_executor/models/deepseek_v2.py    | 75 +++++-----
 vllm/v1/attention/backends/mla/indexer.py    | 131 ++++++++++++------
 3 files changed, 149 insertions(+), 79 deletions(-)

diff --git a/tests/v1/attention/test_sparse_mla_backends.py b/tests/v1/attention/test_sparse_mla_backends.py
index 74eea6f716fef..ddad9342fad0d 100644
--- a/tests/v1/attention/test_sparse_mla_backends.py
+++ b/tests/v1/attention/test_sparse_mla_backends.py
@@ -22,6 +22,7 @@ from vllm.utils import cdiv
 from vllm.v1.attention.backends.mla.flashmla_sparse import (
     FlashMLASparseBackend, FlashMLASparseDecodeAndContextMetadata,
     FlashMLASparseImpl, FlashMLASparseMetadata)
+from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks

 SPARSE_BACKEND_BATCH_SPECS = {
     name: BATCH_SPECS[name]
@@ -424,3 +425,24 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
                            sdpa_reference,
                            rtol=0.5,
                            atol=0.5)
+
+
+@pytest.mark.parametrize(
+    "seq_lens,max_buf,start,expected",
+    [
+        # Basic split: totals per chunk ≤ max_buf
+        (torch.tensor([2, 3, 4, 2]), 5, 0, [(0, 2), (2, 3), (3, 4)]),
+        # Non-zero start index
+        (torch.tensor([2, 3, 4, 2]), 5, 1, [(1, 2), (2, 3), (3, 4)]),
+        # Exact fits should split between items when adding the next would
+        # overflow
+        (torch.tensor([5, 5, 5]), 5, 0, [(0, 1), (1, 2), (2, 3)]),
+        # All requests fit in a single chunk
+        (torch.tensor([1, 1, 1]), 10, 0, [(0, 3)]),
+        # Large buffer with non-zero start
+        (torch.tensor([4, 4, 4]), 100, 1, [(1, 3)]),
+    ],
+)
+def test_split_prefill_chunks(seq_lens, max_buf, start, expected):
+    out = split_prefill_chunks(seq_lens, max_buf, start)
+    assert out == expected
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 03c43654d68f1..b7f96d0d1552e 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -583,44 +583,43 @@ def sparse_attn_indexer(
     topk_indices_buffer[:hidden_states.shape[0]] = -1
     if has_prefill:
         prefill_metadata = attn_metadata.prefill
-        num_prefills = attn_metadata.num_prefills
-        k_fp8 = torch.empty([prefill_metadata.total_seq_lens, head_dim],
-                            device=k.device,
-                            dtype=torch.float8_e4m3fn)
-        k_scale = torch.empty([prefill_metadata.total_seq_lens, 1],
-                              device=k.device,
-                              dtype=torch.float32)
-        cp_gather_indexer_k_quant_cache(
-            kv_cache,
-            k_fp8,
-            k_scale,
-            prefill_metadata.block_table,
-            prefill_metadata.cu_seq_lens,
-            num_prefills,
-        )
-        cu_seqlen_ks = prefill_metadata.cu_seqlen_ks
-        cu_seqlen_ke = prefill_metadata.cu_seqlen_ke
-        num_tokens = attn_metadata.num_actual_tokens
-        logits = fp8_mqa_logits(
-            q_fp8[num_decode_tokens:num_tokens],
-            (k_fp8, k_scale),
-            weights[num_decode_tokens:num_tokens],
-            cu_seqlen_ks,
-            cu_seqlen_ke,
-        )
-        topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
-                                   dim=-1)[1]
-        topk_indices -= cu_seqlen_ks[:, None]
-        mask_lo = topk_indices >= 0
-        mask_hi = topk_indices - (cu_seqlen_ke - cu_seqlen_ks)[:, None] < 0
-        mask = torch.full_like(topk_indices,
-                               False,
-                               dtype=torch.bool,
-                               device=topk_indices.device)
-        mask = mask_lo & mask_hi
-        topk_indices = topk_indices.masked_fill(~mask, -1)
-        topk_indices_buffer[num_decode_tokens:num_tokens, :topk_indices.
-                            shape[-1]] = topk_indices.to(dtype=torch.int32)
+        for chunk in prefill_metadata.chunks:
+            k_fp8 = torch.empty([chunk.total_seq_lens, head_dim],
+                                device=k.device,
+                                dtype=torch.float8_e4m3fn)
+            k_scale = torch.empty([chunk.total_seq_lens, 1],
+                                  device=k.device,
+                                  dtype=torch.float32)
+            cp_gather_indexer_k_quant_cache(
+                kv_cache,
+                k_fp8,
+                k_scale,
+                chunk.block_table,
+                chunk.cu_seq_lens,
+                chunk.num_reqs,
+            )
+            logits = fp8_mqa_logits(
+                q_fp8[chunk.token_start:chunk.token_end],
+                (k_fp8, k_scale),
+                weights[chunk.token_start:chunk.token_end],
+                chunk.cu_seqlen_ks,
+                chunk.cu_seqlen_ke,
+            )
+            topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
+                                       dim=-1)[1]
+            topk_indices -= chunk.cu_seqlen_ks[:, None]
+            mask_lo = topk_indices >= 0
+            mask_hi = topk_indices - (chunk.cu_seqlen_ke -
+                                      chunk.cu_seqlen_ks)[:, None] < 0
+            mask = torch.full_like(topk_indices,
+                                   False,
+                                   dtype=torch.bool,
+                                   device=topk_indices.device)
+            mask = mask_lo & mask_hi
+            topk_indices = topk_indices.masked_fill(~mask, -1)
+            topk_indices_buffer[
+                chunk.token_start:chunk.token_end, :topk_indices.
+                shape[-1]] = topk_indices.to(dtype=torch.int32)

     if has_decode:
         decode_metadata = attn_metadata.decode
diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py
index e87b51b15191f..94b963f34e4a3 100644
--- a/vllm/v1/attention/backends/mla/indexer.py
+++ b/vllm/v1/attention/backends/mla/indexer.py
@@ -49,14 +49,20 @@ class DeepseekV32IndexerBackend(AttentionBackend):


 @dataclass
-class DeepseekV32IndexerPrefillMetadata:
+class DeepseekV32IndexerPrefillChunkMetadata:
     block_table: torch.Tensor
-    query_start_loc: torch.Tensor
-    max_query_len: int
     cu_seqlen_ks: torch.Tensor
     cu_seqlen_ke: torch.Tensor
     cu_seq_lens: torch.Tensor
     total_seq_lens: int
+    token_start: int
+    token_end: int
+    num_reqs: int
+
+
+@dataclass
+class DeepseekV32IndexerPrefillMetadata:
+    chunks: list[DeepseekV32IndexerPrefillChunkMetadata]


 @dataclass
@@ -98,8 +104,8 @@ class DeepseekV32IndexerMetadata:

 # TODO (zyongye) optimize this, this is now vibe coded
 def kv_spans_from_batches(
-        start_seq_loc: torch.Tensor,
-        seq_len_per_batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor,
+        device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Args:
         start_seq_loc: 1D long tensor [B+1], cumulative counts of
@@ -122,7 +128,7 @@ def kv_spans_from_batches(
         are the **last** `counts[i]` positions of that sequence.
     """
     q = start_seq_loc.to(dtype=torch.long)
-    L = seq_len_per_batch.to(dtype=torch.long, device=q.device)
+    L = seq_len_per_batch.to(dtype=torch.long)
     assert q.dim() == 1 and L.dim() == 1
     assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1"

@@ -130,7 +136,6 @@
     counts = q[1:] - q[:-1]  # [B]
     N = int(q[-1].item())  # total selected tokens
     B = L.numel()
-    device = L.device

     if N == 0:
         return (torch.empty(0, dtype=torch.long, device=device),
@@ -140,8 +145,7 @@
     kv_starts_per_batch = torch.cumsum(L, dim=0) - L  # [B]

     # For each selected token, which batch does it belong to?
-    batch_id = torch.repeat_interleave(torch.arange(B, device=device),
-                                       counts)  # [N]
+    batch_id = torch.repeat_interleave(torch.arange(B), counts)  # [N]

     # Map batch KV start to each token
     start_tensor = kv_starts_per_batch[batch_id]  # [N]
@@ -151,22 +155,51 @@ def kv_spans_from_batches(
     L_expand = torch.repeat_interleave(L, counts)  # [N]
     m_expand = torch.repeat_interleave(counts, counts)  # [N]
     # position within the selected block: 1..counts[b]
-    pos_within = (torch.arange(N, device=device, dtype=torch.long) -
+    pos_within = (torch.arange(N, dtype=torch.long) -
                   torch.repeat_interleave(q[:-1], counts) + 1)

     local_pos = L_expand - m_expand + pos_within  # [N], 1-based
     end_location = start_tensor + local_pos  # exclusive end

-    return start_tensor.int(), end_location.int()
+    return start_tensor.int().to(device), end_location.int().to(device)


 def get_max_prefill_buffer_size(vllm_config: VllmConfig):
     max_model_len = vllm_config.model_config.max_model_len
-    # max_num_batched_tokens = \
-    #     vllm_config.scheduler_config.max_num_batched_tokens
-    max_num_seq = vllm_config.scheduler_config.max_num_seqs
-    # NOTE(Chen): an estimated max size of flattened_kv. Need to double check.
-    return max_model_len * max_num_seq
+    # NOTE(Chen): 2 is a magic number for controlling the prefill buffer size.
+    # May be tuned later.
+    return max_model_len * 2
+
+
+def split_prefill_chunks(seq_lens_cpu: torch.Tensor,
+                         max_prefill_buffer_size: int,
+                         reqs_start: int) -> list[tuple[int, int]]:
+    """
+    Split the prefill requests into chunks of (reqs_start, reqs_end) index
+    ranges such that the total sequence length of each chunk does not
+    exceed the maximum prefill buffer size.
+
+    Args:
+        seq_lens_cpu: The sequence lengths of the prefill requests.
+        max_prefill_buffer_size: The maximum prefill buffer size.
+        reqs_start: The start index of the prefill requests.
+
+    Returns:
+        A list of tuples of (reqs_start, reqs_end).
+ """ + chunk_seq_ids = [] + total_seq_lens = 0 + for i in range(reqs_start, len(seq_lens_cpu)): + cur_seq_len = seq_lens_cpu[i].item() + assert cur_seq_len <= max_prefill_buffer_size + total_seq_lens += cur_seq_len + if total_seq_lens > max_prefill_buffer_size: + chunk_seq_ids.append((reqs_start, i)) + reqs_start = i + total_seq_lens = cur_seq_len + if total_seq_lens > 0: + chunk_seq_ids.append((reqs_start, len(seq_lens_cpu))) + return chunk_seq_ids class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): @@ -201,6 +234,33 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): dtype=torch.int32, device=self.device) + def build_one_prefill_chunk(self, reqs_start, reqs_end, + query_start_loc_cpu, seq_lens_cpu, + block_table): + prefill_query_start_loc = query_start_loc_cpu[ + reqs_start:reqs_end + 1] - query_start_loc_cpu[reqs_start] + cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches( + prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end], + self.device) + token_start = query_start_loc_cpu[reqs_start].item() + token_end = query_start_loc_cpu[reqs_end].item() + total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum() + assert total_seq_lens <= self.max_prefill_buffer_size + cu_seq_lens = torch.cat([ + torch.zeros(1, dtype=torch.int32), + seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0) + ]).to(torch.int32).to(self.device) + return DeepseekV32IndexerPrefillChunkMetadata( + cu_seqlen_ks=cu_seqlen_ks, + cu_seqlen_ke=cu_seqlen_ke, + cu_seq_lens=cu_seq_lens, + total_seq_lens=total_seq_lens, + block_table=block_table[reqs_start:reqs_end], + token_start=token_start, + token_end=token_end, + num_reqs=reqs_end - reqs_start, + ) + def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, @@ -209,11 +269,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): num_reqs = common_attn_metadata.num_reqs num_tokens = common_attn_metadata.num_actual_tokens - device = self.device - block_table_tensor = common_attn_metadata.block_table_tensor - - query_start_loc = common_attn_metadata.query_start_loc - + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ split_decodes_and_prefills( common_attn_metadata, @@ -224,27 +280,20 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): prefill_metadata = None if num_prefills > 0: - reqs_start = num_decodes - prefill_query_start_loc = query_start_loc[ - reqs_start:] - query_start_loc[reqs_start] - cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches( - prefill_query_start_loc, - common_attn_metadata.seq_lens[reqs_start:]) - total_seq_lens = common_attn_metadata.seq_lens[reqs_start:].sum() - assert total_seq_lens < self.max_prefill_buffer_size - cu_seq_lens = torch.cat([ - torch.zeros(1, dtype=torch.int32, device=device), - common_attn_metadata.seq_lens[reqs_start:].cumsum(dim=0) - ]).to(torch.int32).cuda() - prefill_metadata = DeepseekV32IndexerPrefillMetadata( - block_table=block_table_tensor[reqs_start:, ...], - query_start_loc=prefill_query_start_loc, - max_query_len=common_attn_metadata.max_query_len, - cu_seqlen_ks=cu_seqlen_ks, - cu_seqlen_ke=cu_seqlen_ke, - cu_seq_lens=cu_seq_lens, - total_seq_lens=total_seq_lens, + chunk_seq_ids = split_prefill_chunks( + common_attn_metadata.seq_lens_cpu, + self.max_prefill_buffer_size, + num_decodes, ) + chunks = [ + self.build_one_prefill_chunk( + reqs_start, reqs_end, query_start_loc_cpu, + common_attn_metadata.seq_lens_cpu, + 
+                    common_attn_metadata.block_table_tensor)
+                for reqs_start, reqs_end in chunk_seq_ids
+            ]
+            prefill_metadata = DeepseekV32IndexerPrefillMetadata(
+                chunks=chunks, )

         decode_metadata = None
         if num_decodes > 0:
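
The chunking policy introduced above is easy to reproduce outside of vLLM. The sketch below is a minimal, self-contained Python mirror of split_prefill_chunks and walks the first parametrized case from test_split_prefill_chunks; the helper name split_prefill_chunks_ref and the example lengths are illustrative only and not part of this patch.

import torch


def split_prefill_chunks_ref(seq_lens_cpu: torch.Tensor,
                             max_prefill_buffer_size: int,
                             reqs_start: int) -> list[tuple[int, int]]:
    # Greedily pack consecutive prefill requests into chunks whose total
    # KV length stays within the flattened-KV buffer budget.
    chunks: list[tuple[int, int]] = []
    total = 0
    for i in range(reqs_start, len(seq_lens_cpu)):
        cur = int(seq_lens_cpu[i])
        assert cur <= max_prefill_buffer_size
        total += cur
        if total > max_prefill_buffer_size:
            # Close the chunk before request i and start a new one at i.
            chunks.append((reqs_start, i))
            reqs_start, total = i, cur
    if total > 0:
        chunks.append((reqs_start, len(seq_lens_cpu)))
    return chunks


# First test case above: buffer of 5 tokens, request KV lengths [2, 3, 4, 2].
print(split_prefill_chunks_ref(torch.tensor([2, 3, 4, 2]), 5, 0))
# -> [(0, 2), (2, 3), (3, 4)]

Each (reqs_start, reqs_end) range then becomes one DeepseekV32IndexerPrefillChunkMetadata via build_one_prefill_chunk, so the per-chunk k_fp8/k_scale gather buffers are bounded by get_max_prefill_buffer_size (max_model_len * 2 tokens) instead of growing with the whole prefill batch.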