[Deepseek v3.2] Support indexer prefill chunking (#25999)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
Chen Zhang 2025-10-02 10:29:12 -07:00 committed by simon-mo
parent 9d9a2b77f1
commit c75c2e70d6
3 changed files with 149 additions and 79 deletions

View File

@@ -22,6 +22,7 @@ from vllm.utils import cdiv
 from vllm.v1.attention.backends.mla.flashmla_sparse import (
     FlashMLASparseBackend, FlashMLASparseDecodeAndContextMetadata,
     FlashMLASparseImpl, FlashMLASparseMetadata)
+from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks
 
 SPARSE_BACKEND_BATCH_SPECS = {
     name: BATCH_SPECS[name]
@@ -424,3 +425,24 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
                           sdpa_reference,
                           rtol=0.5,
                           atol=0.5)
+
+
+@pytest.mark.parametrize(
+    "seq_lens,max_buf,start,expected",
+    [
+        # Basic split: totals per chunk ≤ max_buf
+        (torch.tensor([2, 3, 4, 2]), 5, 0, [(0, 2), (2, 3), (3, 4)]),
+        # Non-zero start index
+        (torch.tensor([2, 3, 4, 2]), 5, 1, [(1, 2), (2, 3), (3, 4)]),
+        # Exact fits should split between items when adding the next would
+        # overflow
+        (torch.tensor([5, 5, 5]), 5, 0, [(0, 1), (1, 2), (2, 3)]),
+        # All requests fit in a single chunk
+        (torch.tensor([1, 1, 1]), 10, 0, [(0, 3)]),
+        # Large buffer with non-zero start
+        (torch.tensor([4, 4, 4]), 100, 1, [(1, 3)]),
+    ],
+)
+def test_split_prefill_chunks(seq_lens, max_buf, start, expected):
+    out = split_prefill_chunks(seq_lens, max_buf, start)
+    assert out == expected

View File

@@ -583,44 +583,43 @@ def sparse_attn_indexer(
     topk_indices_buffer[:hidden_states.shape[0]] = -1
     if has_prefill:
         prefill_metadata = attn_metadata.prefill
-        num_prefills = attn_metadata.num_prefills
-        k_fp8 = torch.empty([prefill_metadata.total_seq_lens, head_dim],
-                            device=k.device,
-                            dtype=torch.float8_e4m3fn)
-        k_scale = torch.empty([prefill_metadata.total_seq_lens, 1],
-                              device=k.device,
-                              dtype=torch.float32)
-        cp_gather_indexer_k_quant_cache(
-            kv_cache,
-            k_fp8,
-            k_scale,
-            prefill_metadata.block_table,
-            prefill_metadata.cu_seq_lens,
-            num_prefills,
-        )
-        cu_seqlen_ks = prefill_metadata.cu_seqlen_ks
-        cu_seqlen_ke = prefill_metadata.cu_seqlen_ke
-        num_tokens = attn_metadata.num_actual_tokens
-        logits = fp8_mqa_logits(
-            q_fp8[num_decode_tokens:num_tokens],
-            (k_fp8, k_scale),
-            weights[num_decode_tokens:num_tokens],
-            cu_seqlen_ks,
-            cu_seqlen_ke,
-        )
-        topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
-                                   dim=-1)[1]
-        topk_indices -= cu_seqlen_ks[:, None]
-        mask_lo = topk_indices >= 0
-        mask_hi = topk_indices - (cu_seqlen_ke - cu_seqlen_ks)[:, None] < 0
-        mask = torch.full_like(topk_indices,
-                               False,
-                               dtype=torch.bool,
-                               device=topk_indices.device)
-        mask = mask_lo & mask_hi
-        topk_indices = topk_indices.masked_fill(~mask, -1)
-        topk_indices_buffer[num_decode_tokens:num_tokens, :topk_indices.
-                            shape[-1]] = topk_indices.to(dtype=torch.int32)
+        for chunk in prefill_metadata.chunks:
+            k_fp8 = torch.empty([chunk.total_seq_lens, head_dim],
+                                device=k.device,
+                                dtype=torch.float8_e4m3fn)
+            k_scale = torch.empty([chunk.total_seq_lens, 1],
+                                  device=k.device,
+                                  dtype=torch.float32)
+            cp_gather_indexer_k_quant_cache(
+                kv_cache,
+                k_fp8,
+                k_scale,
+                chunk.block_table,
+                chunk.cu_seq_lens,
+                chunk.num_reqs,
+            )
+            logits = fp8_mqa_logits(
+                q_fp8[chunk.token_start:chunk.token_end],
+                (k_fp8, k_scale),
+                weights[chunk.token_start:chunk.token_end],
+                chunk.cu_seqlen_ks,
+                chunk.cu_seqlen_ke,
+            )
+            topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
+                                       dim=-1)[1]
+            topk_indices -= chunk.cu_seqlen_ks[:, None]
+            mask_lo = topk_indices >= 0
+            mask_hi = topk_indices - (chunk.cu_seqlen_ke -
+                                      chunk.cu_seqlen_ks)[:, None] < 0
+            mask = torch.full_like(topk_indices,
+                                   False,
+                                   dtype=torch.bool,
+                                   device=topk_indices.device)
+            mask = mask_lo & mask_hi
+            topk_indices = topk_indices.masked_fill(~mask, -1)
+            topk_indices_buffer[
+                chunk.token_start:chunk.token_end, :topk_indices.
+                shape[-1]] = topk_indices.to(dtype=torch.int32)
 
     if has_decode:
         decode_metadata = attn_metadata.decode
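The practical effect of the loop above is that the gather buffers k_fp8 and k_scale are now sized per chunk rather than for the whole prefill batch at once. A rough sketch of the memory bound, using made-up sizes (head_dim and the sequence lengths below are assumptions for illustration):

    # k_fp8 is [total_seq_lens, head_dim] in float8 (1 byte per element);
    # k_scale is [total_seq_lens, 1] in float32 (4 bytes per element).
    def gather_buffer_bytes(total_seq_lens: int, head_dim: int = 128) -> int:
        return total_seq_lens * head_dim + total_seq_lens * 4

    prefill_seq_lens = [8192, 8192, 8192, 8192]  # four long prefills
    max_prefill_buffer_size = 16384              # per-chunk cap

    # Before: one buffer sized for all prefills together.
    print(gather_buffer_bytes(sum(prefill_seq_lens)))    # 4325376
    # After: each chunk's buffer is bounded by the cap.
    print(gather_buffer_bytes(max_prefill_buffer_size))  # 2162688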

View File

@@ -49,14 +49,20 @@ class DeepseekV32IndexerBackend(AttentionBackend):
 
 
 @dataclass
-class DeepseekV32IndexerPrefillMetadata:
+class DeepseekV32IndexerPrefillChunkMetadata:
     block_table: torch.Tensor
-    query_start_loc: torch.Tensor
-    max_query_len: int
     cu_seqlen_ks: torch.Tensor
     cu_seqlen_ke: torch.Tensor
     cu_seq_lens: torch.Tensor
     total_seq_lens: int
+    token_start: int
+    token_end: int
+    num_reqs: int
+
+
+@dataclass
+class DeepseekV32IndexerPrefillMetadata:
+    chunks: list[DeepseekV32IndexerPrefillChunkMetadata]
 
 
 @dataclass
@@ -98,8 +104,8 @@ class DeepseekV32IndexerMetadata:
 
 # TODO (zyongye) optimize this, this is now vibe coded
 def kv_spans_from_batches(
-        start_seq_loc: torch.Tensor,
-        seq_len_per_batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor,
+        device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Args:
         start_seq_loc: 1D long tensor [B+1], cumulative counts of
@@ -122,7 +128,7 @@ def kv_spans_from_batches(
         are the **last** `counts[i]` positions of that sequence.
     """
 
     q = start_seq_loc.to(dtype=torch.long)
-    L = seq_len_per_batch.to(dtype=torch.long, device=q.device)
+    L = seq_len_per_batch.to(dtype=torch.long)
     assert q.dim() == 1 and L.dim() == 1
     assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1"
@@ -130,7 +136,6 @@ def kv_spans_from_batches(
     counts = q[1:] - q[:-1]  # [B]
     N = int(q[-1].item())  # total selected tokens
     B = L.numel()
-    device = L.device
 
     if N == 0:
         return (torch.empty(0, dtype=torch.long, device=device),
@@ -140,8 +145,7 @@ def kv_spans_from_batches(
     kv_starts_per_batch = torch.cumsum(L, dim=0) - L  # [B]
 
     # For each selected token, which batch does it belong to?
-    batch_id = torch.repeat_interleave(torch.arange(B, device=device),
-                                       counts)  # [N]
+    batch_id = torch.repeat_interleave(torch.arange(B), counts)  # [N]
 
     # Map batch KV start to each token
     start_tensor = kv_starts_per_batch[batch_id]  # [N]
@@ -151,22 +155,51 @@ def kv_spans_from_batches(
     L_expand = torch.repeat_interleave(L, counts)  # [N]
     m_expand = torch.repeat_interleave(counts, counts)  # [N]
     # position within the selected block: 1..counts[b]
-    pos_within = (torch.arange(N, device=device, dtype=torch.long) -
+    pos_within = (torch.arange(N, dtype=torch.long) -
                   torch.repeat_interleave(q[:-1], counts) + 1)
 
     local_pos = L_expand - m_expand + pos_within  # [N], 1-based
     end_location = start_tensor + local_pos  # exclusive end
 
-    return start_tensor.int(), end_location.int()
+    return start_tensor.int().to(device), end_location.int().to(device)
 
 
 def get_max_prefill_buffer_size(vllm_config: VllmConfig):
     max_model_len = vllm_config.model_config.max_model_len
-    # max_num_batched_tokens = \
-    #     vllm_config.scheduler_config.max_num_batched_tokens
-    max_num_seq = vllm_config.scheduler_config.max_num_seqs
-    # NOTE(Chen): an estimated max size of flattened_kv. Need to double check.
-    return max_model_len * max_num_seq
+    # NOTE(Chen): 2 is a magic number for controlling the prefill buffer size.
+    # May be tuned later.
+    return max_model_len * 2
+
+
+def split_prefill_chunks(seq_lens_cpu: torch.Tensor,
+                         max_prefill_buffer_size: int,
+                         reqs_start: int) -> list[tuple[int, int]]:
+    """
+    Split the prefill requests into chunks of (reqs_start, reqs_end)
+    tuples such that the total sequence length of each chunk does not
+    exceed the maximum prefill buffer size.
+
+    Args:
+        seq_lens_cpu: The sequence lengths of the prefill requests.
+        max_prefill_buffer_size: The maximum prefill buffer size.
+        reqs_start: The start index of the prefill requests.
+
+    Returns:
+        A list of (reqs_start, reqs_end) tuples.
+    """
+    chunk_seq_ids = []
+    total_seq_lens = 0
+    for i in range(reqs_start, len(seq_lens_cpu)):
+        cur_seq_len = seq_lens_cpu[i].item()
+        assert cur_seq_len <= max_prefill_buffer_size
+        total_seq_lens += cur_seq_len
+        if total_seq_lens > max_prefill_buffer_size:
+            chunk_seq_ids.append((reqs_start, i))
+            reqs_start = i
+            total_seq_lens = cur_seq_len
+    if total_seq_lens > 0:
+        chunk_seq_ids.append((reqs_start, len(seq_lens_cpu)))
+    return chunk_seq_ids
+
+
 class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
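The spans produced by kv_spans_from_batches can be verified by hand for a tiny batch; a sketch under the semantics documented above (the expected values below were computed manually):

    import torch

    from vllm.v1.attention.backends.mla.indexer import kv_spans_from_batches

    # Two sequences with KV lengths 4 and 5; the batch selects the last
    # 2 query positions of the first and the last 1 of the second.
    start_seq_loc = torch.tensor([0, 2, 3])  # cumulative counts, [B+1]
    seq_len_per_batch = torch.tensor([4, 5])

    ks, ke = kv_spans_from_batches(start_seq_loc, seq_len_per_batch,
                                   torch.device("cpu"))
    # Each selected token attends to the flattened KV range [ks, ke):
    # batch 0 starts at slot 0, batch 1 at slot 4, and ends grow causally.
    assert ks.tolist() == [0, 0, 4]
    assert ke.tolist() == [3, 4, 9]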
@@ -201,6 +234,33 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
             dtype=torch.int32,
             device=self.device)
 
+    def build_one_prefill_chunk(self, reqs_start, reqs_end,
+                                query_start_loc_cpu, seq_lens_cpu,
+                                block_table):
+        prefill_query_start_loc = query_start_loc_cpu[
+            reqs_start:reqs_end + 1] - query_start_loc_cpu[reqs_start]
+        cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
+            prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end],
+            self.device)
+        token_start = query_start_loc_cpu[reqs_start].item()
+        token_end = query_start_loc_cpu[reqs_end].item()
+        total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
+        assert total_seq_lens <= self.max_prefill_buffer_size
+        cu_seq_lens = torch.cat([
+            torch.zeros(1, dtype=torch.int32),
+            seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0)
+        ]).to(torch.int32).to(self.device)
+        return DeepseekV32IndexerPrefillChunkMetadata(
+            cu_seqlen_ks=cu_seqlen_ks,
+            cu_seqlen_ke=cu_seqlen_ke,
+            cu_seq_lens=cu_seq_lens,
+            total_seq_lens=total_seq_lens,
+            block_table=block_table[reqs_start:reqs_end],
+            token_start=token_start,
+            token_end=token_end,
+            num_reqs=reqs_end - reqs_start,
+        )
+
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
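For intuition, build_one_prefill_chunk's bookkeeping can be traced with small CPU-side tensors (the values below are hypothetical):

    import torch

    # Hypothetical batch: query_start_loc over 4 requests; the chunk
    # covers requests [1, 3).
    query_start_loc_cpu = torch.tensor([0, 4, 6, 9, 11])
    seq_lens_cpu = torch.tensor([4, 7, 8, 5])
    reqs_start, reqs_end = 1, 3

    token_start = query_start_loc_cpu[reqs_start].item()  # 4
    token_end = query_start_loc_cpu[reqs_end].item()      # 9
    total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum().item()  # 15
    cu_seq_lens = torch.cat([
        torch.zeros(1, dtype=torch.int32),
        seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0)
    ]).to(torch.int32)

    # The chunk scores query rows q_fp8[4:9] against a 15-slot KV buffer
    # whose per-request boundaries are cu_seq_lens = [0, 7, 15].
    assert (token_start, token_end, total_seq_lens) == (4, 9, 15)
    assert cu_seq_lens.tolist() == [0, 7, 15]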
@@ -209,11 +269,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
         num_reqs = common_attn_metadata.num_reqs
         num_tokens = common_attn_metadata.num_actual_tokens
-        device = self.device
-
-        block_table_tensor = common_attn_metadata.block_table_tensor
-
-        query_start_loc = common_attn_metadata.query_start_loc
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
 
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
             split_decodes_and_prefills(
                 common_attn_metadata,
@@ -224,27 +280,20 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
         prefill_metadata = None
         if num_prefills > 0:
             reqs_start = num_decodes
-            prefill_query_start_loc = query_start_loc[
-                reqs_start:] - query_start_loc[reqs_start]
-            cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
-                prefill_query_start_loc,
-                common_attn_metadata.seq_lens[reqs_start:])
-            total_seq_lens = common_attn_metadata.seq_lens[reqs_start:].sum()
-            assert total_seq_lens < self.max_prefill_buffer_size
-            cu_seq_lens = torch.cat([
-                torch.zeros(1, dtype=torch.int32, device=device),
-                common_attn_metadata.seq_lens[reqs_start:].cumsum(dim=0)
-            ]).to(torch.int32).cuda()
-
-            prefill_metadata = DeepseekV32IndexerPrefillMetadata(
-                block_table=block_table_tensor[reqs_start:, ...],
-                query_start_loc=prefill_query_start_loc,
-                max_query_len=common_attn_metadata.max_query_len,
-                cu_seqlen_ks=cu_seqlen_ks,
-                cu_seqlen_ke=cu_seqlen_ke,
-                cu_seq_lens=cu_seq_lens,
-                total_seq_lens=total_seq_lens,
-            )
+            chunk_seq_ids = split_prefill_chunks(
+                common_attn_metadata.seq_lens_cpu,
+                self.max_prefill_buffer_size,
+                num_decodes,
+            )
+            chunks = [
+                self.build_one_prefill_chunk(
+                    reqs_start, reqs_end, query_start_loc_cpu,
+                    common_attn_metadata.seq_lens_cpu,
+                    common_attn_metadata.block_table_tensor)
+                for reqs_start, reqs_end in chunk_seq_ids
+            ]
+            prefill_metadata = DeepseekV32IndexerPrefillMetadata(
+                chunks=chunks, )
 
         decode_metadata = None
         if num_decodes > 0:
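Putting build() together: decode requests keep the batch positions below num_decodes, and only the prefill tail gets chunked. A sketch with an assumed batch (the 16-token cap below stands in for get_max_prefill_buffer_size, i.e. max_model_len * 2):

    import torch

    from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks

    # Hypothetical batch: 2 decodes followed by 3 prefills.
    seq_lens_cpu = torch.tensor([1, 1, 10, 9, 4])
    num_decodes = 2

    chunk_seq_ids = split_prefill_chunks(seq_lens_cpu, 16, num_decodes)
    # Requests 2 and 3 together would overflow the cap (10 + 9 > 16), so
    # the prefills split into two chunks; decode requests are untouched.
    assert chunk_seq_ids == [(2, 3), (3, 5)]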