mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-29 09:37:14 +08:00
[bugfix] correct local_chunk_len for DCP in reorg_kvcache with long context (#28526)
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent
5d6ce2b960
commit
968060c15a
@ -337,6 +337,7 @@ class MLACommonPrefillMetadata:
|
||||
local_context_lens_allranks: list[list[int]] | None = None
|
||||
padded_local_cu_seq_lens: torch.Tensor | None = None
|
||||
cu_seq_lens_lst: list[list[int]] | None = None
|
||||
chunk_size: int | None = None
|
||||
|
||||
block_table: torch.Tensor
|
||||
query_start_loc: torch.Tensor
|
||||
@ -902,6 +903,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
device, non_blocking=True
|
||||
),
|
||||
cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
|
||||
chunk_size=padded_local_max_context_chunk_across_ranks,
|
||||
)
|
||||
else:
|
||||
chunked_context_metadata = chunked_context_metadata_cls(
|
||||
@ -986,6 +988,8 @@ def reorg_kvcache(
|
||||
local_context_lens_allranks: list[list[int]],
|
||||
sum_seq_len: int,
|
||||
max_seq_len: int,
|
||||
chunk_size: int,
|
||||
chunk_idx: int,
|
||||
toks: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
@ -1001,6 +1005,9 @@ def reorg_kvcache(
|
||||
local_context_lens_allranks: local context lengths on each CP rank.
|
||||
sum_seq_len: the sum of cp_chunk_seq_lens_lst.
|
||||
max_seq_len: the max value of cp_chunk_seq_lens_lst.
|
||||
chunk_size: the local padded max context chunk from
|
||||
chunked_context_metadata building.
|
||||
chunk_idx: chunk idx of chunked_prefill.
|
||||
toks: the number of tokens for local gather cache.
|
||||
"""
|
||||
kv_c_segments = []
|
||||
@ -1012,20 +1019,31 @@ def reorg_kvcache(
|
||||
):
|
||||
cur_seq_len = 0
|
||||
for rank, local_context_len in enumerate(local_context_lens):
|
||||
if local_context_len != 0:
|
||||
# Note(qcs): We split the context into multiple chunks,
|
||||
# depending on the size of the workspace.
|
||||
# local_context in dcp0: |-----------------|
|
||||
# local_context in dcp1: |--------------|
|
||||
# n*padded_local_chunk: |-----|-----|-----|
|
||||
# local_chunk_len in dcp1: |-----|-----|--|
|
||||
# so we need update the last chunk length in dcp1.
|
||||
local_chunk_len = min(
|
||||
max(0, local_context_len - chunk_idx * chunk_size),
|
||||
padded_local_chunk_seq_len,
|
||||
)
|
||||
if local_chunk_len != 0:
|
||||
kv_c_segment = allgatered_kv_c_normed[
|
||||
rank * toks + src_token_idx : rank * toks
|
||||
+ src_token_idx
|
||||
+ local_context_len
|
||||
+ local_chunk_len
|
||||
]
|
||||
k_pe_segment = allgatered_k_pe[
|
||||
rank * toks + src_token_idx : rank * toks
|
||||
+ src_token_idx
|
||||
+ local_context_len
|
||||
+ local_chunk_len
|
||||
]
|
||||
kv_c_segments.append(kv_c_segment)
|
||||
k_pe_segments.append(k_pe_segment)
|
||||
cur_seq_len += local_context_len
|
||||
cur_seq_len += local_chunk_len
|
||||
max_seq_len_check = max(max_seq_len_check, cur_seq_len)
|
||||
src_token_idx += padded_local_chunk_seq_len
|
||||
reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
|
||||
@ -1676,6 +1694,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
assert prefill_metadata.chunked_context.local_context_lens_allranks is not None
|
||||
assert prefill_metadata.chunked_context.padded_local_cu_seq_lens is not None
|
||||
assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None
|
||||
assert prefill_metadata.chunked_context.chunk_size is not None
|
||||
|
||||
output = None
|
||||
iters = len(prefill_metadata.chunked_context.seq_tot)
|
||||
@ -1725,6 +1744,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
local_context_lens_allranks=prefill_metadata.chunked_context.local_context_lens_allranks,
|
||||
sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1],
|
||||
max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i],
|
||||
chunk_size=prefill_metadata.chunked_context.chunk_size,
|
||||
chunk_idx=i,
|
||||
toks=toks,
|
||||
)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user