[bugfix] correct local_chunk_len for DCP in reorg_kvcache with long context (#28526)

Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Qiu 2025-11-14 03:29:22 +08:00 committed by GitHub
parent 5d6ce2b960
commit 968060c15a

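In short: with decode context parallel (DCP), a long context is split into workspace-sized chunks, and each rank contributes at most its remaining local context to any given chunk. `reorg_kvcache` previously advanced through the all-gathered buffers by `local_context_len` regardless of the chunk, which reads past a rank's contribution once its context spans more than one chunk. The patch threads `chunk_size` and `chunk_idx` through and clamps the per-chunk length. A minimal standalone sketch of that clamp (toy values, not the vLLM code itself; the real logic is in `reorg_kvcache` below):

```python
# Sketch of the per-rank, per-chunk length clamp added by this patch.
# `chunk_size` is the padded local max context chunk; `padded_local_chunk_seq_len`
# is the padded extent of the current chunk (assumed equal to chunk_size here).
def local_chunk_len(
    local_context_len: int,           # this rank's total local context tokens
    chunk_size: int,                  # padded local max context chunk per rank
    chunk_idx: int,                   # index of the current chunked-prefill chunk
    padded_local_chunk_seq_len: int,  # padded length of the current chunk
) -> int:
    # Tokens this rank still has left once the previous chunks are consumed...
    remaining = max(0, local_context_len - chunk_idx * chunk_size)
    # ...capped by the extent of the current chunk.
    return min(remaining, padded_local_chunk_seq_len)

# Toy example matching the diagram in the diff: dcp1 holds 14 tokens and
# chunks are padded to 6 tokens, so its chunks contribute 6, 6, and 2 tokens.
assert [local_chunk_len(14, 6, i, 6) for i in range(3)] == [6, 6, 2]
```

The old code effectively used `local_context_len` in place of this clamp, which only matches the correct value when the whole local context fits in a single chunk.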

@@ -337,6 +337,7 @@ class MLACommonPrefillMetadata:
local_context_lens_allranks: list[list[int]] | None = None
padded_local_cu_seq_lens: torch.Tensor | None = None
cu_seq_lens_lst: list[list[int]] | None = None
chunk_size: int | None = None
block_table: torch.Tensor
query_start_loc: torch.Tensor
@@ -902,6 +903,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
device, non_blocking=True
),
cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
chunk_size=padded_local_max_context_chunk_across_ranks,
)
else:
chunked_context_metadata = chunked_context_metadata_cls(
@@ -986,6 +988,8 @@ def reorg_kvcache(
local_context_lens_allranks: list[list[int]],
sum_seq_len: int,
max_seq_len: int,
chunk_size: int,
chunk_idx: int,
toks: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
@@ -1001,6 +1005,9 @@ def reorg_kvcache(
local_context_lens_allranks: local context lengths on each CP rank.
sum_seq_len: the sum of cp_chunk_seq_lens_lst.
max_seq_len: the max value of cp_chunk_seq_lens_lst.
chunk_size: the padded local max context chunk length used when
building chunked_context_metadata.
chunk_idx: index of the current chunk in the chunked prefill.
toks: the number of tokens for local gather cache.
"""
kv_c_segments = []
@@ -1012,20 +1019,31 @@
):
cur_seq_len = 0
for rank, local_context_len in enumerate(local_context_lens):
if local_context_len != 0:
# Note(qcs): We split the context into multiple chunks,
# depending on the size of the workspace.
# local_context in dcp0: |-----------------|
# local_context in dcp1: |--------------|
# n*padded_local_chunk: |-----|-----|-----|
# local_chunk_len in dcp1: |-----|-----|--|
# so we need to update the length of the last chunk on dcp1.
local_chunk_len = min(
max(0, local_context_len - chunk_idx * chunk_size),
padded_local_chunk_seq_len,
)
if local_chunk_len != 0:
kv_c_segment = allgatered_kv_c_normed[
rank * toks + src_token_idx : rank * toks
+ src_token_idx
+ local_context_len
+ local_chunk_len
]
k_pe_segment = allgatered_k_pe[
rank * toks + src_token_idx : rank * toks
+ src_token_idx
+ local_context_len
+ local_chunk_len
]
kv_c_segments.append(kv_c_segment)
k_pe_segments.append(k_pe_segment)
cur_seq_len += local_context_len
cur_seq_len += local_chunk_len
max_seq_len_check = max(max_seq_len_check, cur_seq_len)
src_token_idx += padded_local_chunk_seq_len
reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
@@ -1676,6 +1694,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
assert prefill_metadata.chunked_context.local_context_lens_allranks is not None
assert prefill_metadata.chunked_context.padded_local_cu_seq_lens is not None
assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None
assert prefill_metadata.chunked_context.chunk_size is not None
output = None
iters = len(prefill_metadata.chunked_context.seq_tot)
@@ -1725,6 +1744,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
local_context_lens_allranks=prefill_metadata.chunked_context.local_context_lens_allranks,
sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1],
max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i],
chunk_size=prefill_metadata.chunked_context.chunk_size,
chunk_idx=i,
toks=toks,
)
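
For reviewers, one property the clamp restores (a toy check with made-up lengths, not part of the patch): summing the per-chunk lengths over all chunk indices recovers each rank's local context length, so no tokens are skipped or double-counted when the context spans several workspace chunks.

```python
# Toy check: per-rank, per-chunk accounting with hypothetical lengths.
# Two DCP ranks with uneven local contexts and a padded chunk size of 6;
# padded_local_chunk_seq_len is assumed equal to chunk_size for simplicity.
chunk_size = 6
local_context_lens = [17, 14]
num_chunks = 3  # enough chunks to cover the longest local context

for local_context_len in local_context_lens:
    covered = 0
    for chunk_idx in range(num_chunks):
        # Same formula the patch introduces in reorg_kvcache.
        local_chunk_len = min(
            max(0, local_context_len - chunk_idx * chunk_size),
            chunk_size,
        )
        covered += local_chunk_len
    assert covered == local_context_len  # nothing dropped, nothing repeated
```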