From 968060c15adc0b68a76d37db00acf1273a23b829 Mon Sep 17 00:00:00 2001
From: Qiu
Date: Fri, 14 Nov 2025 03:29:22 +0800
Subject: [PATCH] [bugfix] correct local_chunk_len for DCP in reorg_kvcache
 with long context (#28526)

Signed-off-by: QiuChunshuo
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 vllm/v1/attention/backends/mla/common.py | 29 +++++++++++++++++++++++++----
 1 file changed, 25 insertions(+), 4 deletions(-)

diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 467c01cd9d069..2ccdd1f143ce8 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -337,6 +337,7 @@ class MLACommonPrefillMetadata:
         local_context_lens_allranks: list[list[int]] | None = None
         padded_local_cu_seq_lens: torch.Tensor | None = None
         cu_seq_lens_lst: list[list[int]] | None = None
+        chunk_size: int | None = None
 
     block_table: torch.Tensor
     query_start_loc: torch.Tensor
@@ -902,6 +903,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                         device, non_blocking=True
                     ),
                     cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
+                    chunk_size=padded_local_max_context_chunk_across_ranks,
                 )
             else:
                 chunked_context_metadata = chunked_context_metadata_cls(
@@ -986,6 +988,8 @@ def reorg_kvcache(
     local_context_lens_allranks: list[list[int]],
     sum_seq_len: int,
     max_seq_len: int,
+    chunk_size: int,
+    chunk_idx: int,
     toks: int,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
@@ -1001,6 +1005,9 @@ def reorg_kvcache(
         local_context_lens_allranks: local context lengths on each CP rank.
         sum_seq_len: the sum of cp_chunk_seq_lens_lst.
         max_seq_len: the max value of cp_chunk_seq_lens_lst.
+        chunk_size: the padded local max context chunk size used when
+            building chunked_context_metadata.
+        chunk_idx: index of the current chunked-prefill chunk.
         toks: the number of tokens for local gather cache.
     """
     kv_c_segments = []
@@ -1012,20 +1019,31 @@ def reorg_kvcache(
     ):
         cur_seq_len = 0
         for rank, local_context_len in enumerate(local_context_lens):
-            if local_context_len != 0:
+            # Note(qcs): We split the context into multiple chunks,
+            # depending on the size of the workspace.
+            # local_context in dcp0:   |-----------------|
+            # local_context in dcp1:   |--------------|
+            # n*padded_local_chunk:    |-----|-----|-----|
+            # local_chunk_len in dcp1: |-----|-----|--|
+            # so we need to update the last chunk length in dcp1.
+            local_chunk_len = min(
+                max(0, local_context_len - chunk_idx * chunk_size),
+                padded_local_chunk_seq_len,
+            )
+            if local_chunk_len != 0:
                 kv_c_segment = allgatered_kv_c_normed[
                     rank * toks + src_token_idx : rank * toks
                     + src_token_idx
-                    + local_context_len
+                    + local_chunk_len
                 ]
                 k_pe_segment = allgatered_k_pe[
                     rank * toks + src_token_idx : rank * toks
                     + src_token_idx
-                    + local_context_len
+                    + local_chunk_len
                 ]
                 kv_c_segments.append(kv_c_segment)
                 k_pe_segments.append(k_pe_segment)
-                cur_seq_len += local_context_len
+                cur_seq_len += local_chunk_len
         max_seq_len_check = max(max_seq_len_check, cur_seq_len)
         src_token_idx += padded_local_chunk_seq_len
     reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
@@ -1676,6 +1694,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
             assert prefill_metadata.chunked_context.local_context_lens_allranks is not None
             assert prefill_metadata.chunked_context.padded_local_cu_seq_lens is not None
             assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None
+            assert prefill_metadata.chunked_context.chunk_size is not None
 
         output = None
         iters = len(prefill_metadata.chunked_context.seq_tot)
@@ -1725,6 +1744,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
                     local_context_lens_allranks=prefill_metadata.chunked_context.local_context_lens_allranks,
                     sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1],
                     max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i],
+                    chunk_size=prefill_metadata.chunked_context.chunk_size,
+                    chunk_idx=i,
                     toks=toks,
                 )
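
For readers tracing the fix: the sketch below re-derives the per-chunk local length that the new local_chunk_len expression computes inside reorg_kvcache. It is an illustrative, self-contained Python snippet; the standalone helper and the sizes (48 and 40 context tokens on two DCP ranks, 16-token padded chunks) are assumptions chosen for demonstration, not values taken from the patch.

# Illustrative sketch only: mirrors the local_chunk_len expression added above,
# with hypothetical sizes.
def local_chunk_len(
    local_context_len: int,
    chunk_idx: int,
    chunk_size: int,
    padded_local_chunk_seq_len: int,
) -> int:
    # Tokens of this rank's context that fall into chunk `chunk_idx`,
    # floored at zero and capped at the padded chunk length.
    return min(
        max(0, local_context_len - chunk_idx * chunk_size),
        padded_local_chunk_seq_len,
    )

# Example: dcp0 holds 48 context tokens, dcp1 holds 40; chunks are 16 wide.
for rank, ctx in enumerate([48, 40]):
    print(f"dcp{rank}:", [local_chunk_len(ctx, i, 16, 16) for i in range(3)])
# dcp0: [16, 16, 16]
# dcp1: [16, 16, 8]   <- the last chunk shrinks, matching the Note in the diff;
#                        before the fix, the full local_context_len was gathered
#                        for every chunk.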