mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 00:15:24 +08:00)
[DCP] Support dcp kv_cache interleave size > 1 (#26696)
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Signed-off-by: Qiu <qiuchunshuo@huawei.com>
Co-authored-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
parent
47604137a2
commit
2108a571d7
@ -30,6 +30,7 @@ class ParallelSetup(NamedTuple):
tp_size: int
pp_size: int
dcp_size: int
dcp_kv_cache_interleave_size: int
eager_mode: bool
chunked_prefill: bool

@ -52,6 +53,7 @@ class CPTestSettings:
tp_base: int = 4,
pp_base: int = 1,
dcp_base: int = 1,
dcp_kv_cache_interleave_size: int = 1,
multi_node_only: bool = False,
runner: RunnerOption = "auto",
load_format: str | None = None,

@ -66,6 +68,7 @@ class CPTestSettings:
tp_size=tp_base,
pp_size=pp_multiplier * pp_base,
dcp_size=int(dcp_multiplier * tp_base),
dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
eager_mode=eager_mode_val,
chunked_prefill=chunked_prefill_val,
)

@ -108,6 +111,7 @@ def _compare_cp_with_tp(
tp_size,
pp_size,
dcp_size,
dcp_kv_cache_interleave_size,
eager_mode,
chunked_prefill,
) = parallel_setup

@ -180,6 +184,8 @@ def _compare_cp_with_tp(
str(pp_size),
"--decode-context-parallel-size",
str(dcp_size),
"--dcp-kv-cache-interleave-size",
str(dcp_kv_cache_interleave_size),
"--distributed-executor-backend",
distributed_backend,
]

@ -207,6 +213,7 @@ CP_TEXT_GENERATION_MODELS = {
"deepseek-ai/DeepSeek-V2-Lite-Chat": [
CPTestSettings.detailed(),
CPTestSettings.detailed(tp_base=2),
CPTestSettings.detailed(tp_base=2, dcp_kv_cache_interleave_size=64),
],
"bigcode/gpt_bigcode-santacoder": [
CPTestSettings.detailed(),
@ -951,6 +951,7 @@ def test_hybrid_block_table_initialization():
max_num_reqs = 10
max_num_blocks_per_req = 20
max_num_batched_tokens = 512
dcp_kv_cache_interleave_size = 8

block_table = BlockTable(
block_size=block_size,

@ -960,6 +961,7 @@ def test_hybrid_block_table_initialization():
pin_memory=False,
device=torch.device(DEVICE),
kernel_block_size=kernel_block_sizes[0],
dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
)

# Verify hybrid block configuration
@ -53,6 +53,7 @@ def _correct_attn_cp_out_kernel(
lse = tl.load(lses_ptr + lse_offsets)
lse = tl.where((lse != lse) | (lse == float("inf")), -float("inf"), lse)
lse_max = tl.max(lse, axis=0)
lse_max = tl.where(lse_max == -float("inf"), 0, lse_max)
lse -= lse_max
lse_exp = tl.exp(lse)
lse_acc = tl.sum(lse_exp, axis=0)
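The kernel change above treats a rank's log-sum-exp (LSE) value that is NaN or +inf as -inf, and clamps an all -inf maximum to 0, so ranks that hold no KV for a token contribute zero weight when the partial attention outputs are merged. A minimal PyTorch sketch of that log-sum-exp merge (function, tensor names, and shapes are illustrative, not vLLM's API):

    import torch

    def merge_cp_attn_outputs(outs: torch.Tensor, lses: torch.Tensor) -> torch.Tensor:
        # outs: [world_size, num_tokens, head_dim] partial outputs, one slice per (D)CP rank
        # lses: [world_size, num_tokens] per-rank log-sum-exp values
        # NaN or +inf LSE means the rank attended to nothing for that token -> weight 0.
        lses = torch.where(lses.isnan() | (lses == float("inf")),
                           torch.full_like(lses, float("-inf")), lses)
        lse_max = lses.max(dim=0, keepdim=True).values
        lse_max = torch.where(lse_max == float("-inf"), torch.zeros_like(lse_max), lse_max)
        weights = torch.exp(lses - lse_max)                  # [world_size, num_tokens]
        denom = weights.sum(dim=0).clamp_min(1e-20)          # avoid 0/0 for fully masked tokens
        return (outs * weights.unsqueeze(-1)).sum(dim=0) / denom.unsqueeze(-1)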
@ -227,6 +227,17 @@ class ParallelConfig:
not changed by dcp; it simply reuses the GPUs of the TP group, and tp_size
needs to be divisible by dcp_size."""

dcp_kv_cache_interleave_size: int = 1
"""Interleave size of kv_cache storage while using dcp or cp > 1:
store interleave_size tokens on (d)cp rank i,
then store the next interleave_size tokens on (d)cp rank i+1.
Interleave_size=1: token-level alignment, token i is stored on rank i % (d)cp_size.
Interleave_size=block_size: block-level alignment; the block on the first rank is
filled first, and a token goes to rank i+1 block j only after rank i block j is full.
Block_size should be greater than or equal to dcp_kv_cache_interleave_size.
Block_size should be divisible by dcp_kv_cache_interleave_size.
"""

_api_process_count: int = Field(default=1, gt=0)
"""
The number of API processes initialized.
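The docstring above defines how tokens are striped across (D)CP ranks. As a plain illustration (not vLLM code), the owning rank and the offset within that rank's cache for a global token index can be computed as follows, assuming interleave_size divides block_size:

    def dcp_owner_and_local_offset(token_idx: int, interleave_size: int, dcp_world_size: int):
        """Return (rank, local_token_offset) for a token under interleaved DCP storage."""
        group = token_idx // interleave_size     # which interleave group the token falls in
        rank = group % dcp_world_size            # groups are dealt to ranks round-robin
        local_group = group // dcp_world_size    # full groups this rank already holds
        local_offset = local_group * interleave_size + token_idx % interleave_size
        return rank, local_offset

    # interleave_size=1 reproduces token-level striping: token i lives on rank i % world_size.
    # interleave_size=block_size fills a whole block on rank 0 before moving to rank 1.
    assert dcp_owner_and_local_offset(5, 1, 4) == (1, 1)
    assert dcp_owner_and_local_offset(5, 4, 4) == (1, 1)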
@ -608,6 +608,23 @@ class VllmConfig:
)
current_platform.check_and_update_config(self)

assert (
self.parallel_config.dcp_kv_cache_interleave_size
<= self.cache_config.block_size
and self.cache_config.block_size
% self.parallel_config.dcp_kv_cache_interleave_size
== 0
), (
f"Block_size({self.cache_config.block_size}) should be "
"greater than or equal to and divisible by dcp_kv_cache_interleave_size "
f"({self.parallel_config.dcp_kv_cache_interleave_size})."
)

assert (
self.parallel_config.dcp_kv_cache_interleave_size == 1
or self.speculative_config is None
), "MTP with dcp_kv_cache_interleave_size > 1 is not supported now."

# Do this after all the updates to compilation_config.mode
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
self.compilation_config.set_splitting_ops_for_v1()
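A quick sanity check of the constraint enforced above, with made-up numbers:

    block_size = 16
    for interleave in (1, 2, 4, 8, 16):   # valid: each one divides 16 and does not exceed it
        assert interleave <= block_size and block_size % interleave == 0
    for interleave in (3, 32):            # invalid: 3 does not divide 16, 32 exceeds block_size
        assert not (interleave <= block_size and block_size % interleave == 0)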
@ -385,6 +385,7 @@ class EngineArgs:
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
data_parallel_size: int = ParallelConfig.data_parallel_size
data_parallel_rank: int | None = None
data_parallel_start_rank: int | None = None

@ -750,6 +751,10 @@ class EngineArgs:
"-dcp",
**parallel_kwargs["decode_context_parallel_size"],
)
parallel_group.add_argument(
"--dcp-kv-cache-interleave-size",
**parallel_kwargs["dcp_kv_cache_interleave_size"],
)
parallel_group.add_argument(
"--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
)

@ -1518,6 +1523,7 @@ class EngineArgs:
worker_cls=self.worker_cls,
worker_extension_cls=self.worker_extension_cls,
decode_context_parallel_size=self.decode_context_parallel_size,
dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size,
_api_process_count=self._api_process_count,
_api_process_rank=self._api_process_rank,
)
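With the flag wired into EngineArgs, the interleave size can be set on the command line next to the existing DCP flag. An illustrative invocation (model and sizes are placeholders, mirroring the test settings above):

    vllm serve deepseek-ai/DeepSeek-V2-Lite-Chat \
        --tensor-parallel-size 2 \
        --decode-context-parallel-size 2 \
        --dcp-kv-cache-interleave-size 64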
@ -43,6 +43,7 @@ from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
get_dcp_local_seq_lens,
get_kv_cache_layout,
)
from vllm.v1.kv_cache_interface import AttentionSpec

@ -238,6 +239,10 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
self.dcp_world_size = 1
self.dcp_rank = 0

self.dcp_kv_cache_interleave_size = (
self.parallel_config.dcp_kv_cache_interleave_size
)

self.use_full_cuda_graph = (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
)

@ -352,8 +357,12 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
- common_attn_metadata.query_start_loc_cpu[:-1]
)
dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu
dcp_context_kv_lens_cpu = dcp_context_kv_lens_cpu // self.dcp_world_size + (
self.dcp_rank <= (dcp_context_kv_lens_cpu - 1) % self.dcp_world_size
)
dcp_context_kv_lens_cpu = get_dcp_local_seq_lens(
dcp_context_kv_lens_cpu,
self.dcp_world_size,
self.dcp_rank,
self.dcp_kv_cache_interleave_size,
)
dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
max_dcp_context_kv_len = dcp_context_kv_lens.max().item()
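In this decode path the context KV length per request is seq_len minus the new query tokens, and the old "// world_size plus rank-remainder" split is replaced by the shared helper so it respects the interleave size. A rough sketch with made-up numbers:

    # hypothetical decode request: 76 scheduled tokens, 6 of them are the new query
    seq_len, query_len = 76, 6
    context_kv_len = seq_len - query_len   # 70 tokens of already-cached context
    # with dcp_world_size=4 and dcp_kv_cache_interleave_size=8, get_dcp_local_seq_lens
    # distributes the 70 tokens as [22, 16, 16, 16]: eight full 8-token groups go
    # round-robin to ranks 0..3, and the trailing 6-token group lands on rank 0.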
@ -225,6 +225,7 @@ from vllm.utils.math_utils import cdiv, round_down
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
get_dcp_local_seq_lens,
get_per_layer_parameters,
infer_global_hyperparameters,
split_decodes_and_prefills,

@ -361,10 +362,9 @@ class MLACommonPrefillMetadata:
workspace: torch.Tensor

# for mla DCP
cp_chunk_seq_lens: list[list[int]] | None = None
origin_context_lens: list[int] | None = None
cp_cu_seq_lens: torch.Tensor | None = None
chunk_size: int | None = None
padded_local_chunk_seq_lens: list[list[int]] | None = None
local_context_lens_allranks: list[list[int]] | None = None
padded_local_cu_seq_lens: torch.Tensor | None = None
cu_seq_lens_lst: list[list[int]] | None = None

block_table: torch.Tensor
@ -568,6 +568,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
self.dcp_local_block_size = parallel_config.dcp_kv_cache_interleave_size
self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size

# Don't try to access the runner on AMD
if self.aot_schedule:
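The two derived sizes above drive the chunked-prefill addressing later in this builder: each rank locally holds interleave-sized runs of tokens, and one sweep over all ranks covers an interleave-size times world-size span. A tiny numeric illustration (values are made up):

    dcp_kv_cache_interleave_size = 8                                  # tokens kept contiguously on one rank
    dcp_world_size = 4
    dcp_local_block_size = dcp_kv_cache_interleave_size               # 8
    dcp_virtual_block_size = dcp_local_block_size * dcp_world_size    # 32 tokens per round-robin sweep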
@ -794,15 +796,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
)
)

# Note(hc): update seq_lens of decode reqs under DCP.
if self.dcp_world_size > 1:
assert dcp_local_seq_lens is not None
dcp_local_seq_lens[:num_decodes] = seq_lens[
:num_decodes
] // self.dcp_world_size + (
self.dcp_rank < seq_lens[:num_decodes] % self.dcp_world_size
)

assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_tokens

@ -811,11 +804,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
reqs_start = num_decodes  # prefill_start

context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
# Note(hc): The context lengths in the perspective of dcp rank0.
cp_context_lens_cpu = torch.ceil(
context_lens_cpu.float() / self.dcp_world_size
).int()
origin_context_lens = context_lens_cpu.tolist()
max_context_len_cpu = context_lens_cpu.max().item()
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
prefill_query_start_loc = (
@ -871,32 +859,56 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
)

if self.dcp_world_size > 1:
local_context_lens_allranks = get_dcp_local_seq_lens(
context_lens_cpu,
self.dcp_world_size,
None,
self.dcp_local_block_size,
)
# Note(qcs): The max local context lengths,
# padded to `dcp_local_block_size`.
padded_local_context_lens_cpu = (
cdiv(
context_lens_cpu,
self.dcp_virtual_block_size,
)
* self.dcp_local_block_size
)
# Note(hc): The above max_context_chunk already enforces
# block_size alignment; DCP just needs block_size to be
# divisible by dcp_world_size, because DCP uses
# cp_gather_cache, which does not require `cp_chunk_starts`
# to be aligned to page_size.
assert max_context_chunk % self.dcp_world_size == 0
cp_max_context_chunk = max_context_chunk // self.dcp_world_size
cp_chunk_starts = (
padded_local_max_context_chunk_across_ranks = (
cdiv(
max_context_chunk,
self.dcp_virtual_block_size,
)
* self.dcp_local_block_size
)
local_chunk_starts = (
torch.arange(num_chunks, dtype=torch.int32)
.unsqueeze(1)
.expand(-1, num_prefills)
* cp_max_context_chunk
* padded_local_max_context_chunk_across_ranks
)
cp_chunk_ends = torch.min(
cp_context_lens_cpu.unsqueeze(0),
cp_chunk_starts + cp_max_context_chunk,
local_chunk_ends = torch.min(
padded_local_context_lens_cpu.unsqueeze(0),
local_chunk_starts
+ padded_local_max_context_chunk_across_ranks,
)
cp_chunk_seq_lens = (cp_chunk_ends - cp_chunk_starts).clamp(min=0)
padded_local_chunk_seq_lens = (
local_chunk_ends - local_chunk_starts
).clamp(min=0)

cp_cu_seq_lens_cpu = torch.zeros(
padded_local_cu_chunk_seq_lens_cpu = torch.zeros(
num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True
)
torch.cumsum(
cp_chunk_seq_lens,
padded_local_chunk_seq_lens,
dim=1,
out=cp_cu_seq_lens_cpu[:, 1:],
out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
dtype=torch.int32,
)
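The padding above rounds each rank's local chunk up to a whole number of interleave groups, so every rank gathers equally sized, index-aligned buffers during chunked prefill. A small worked example of padded_local_context_lens_cpu, using made-up numbers and the names from this hunk:

    from math import ceil

    context_len = 70                    # cached context tokens for one prefill request
    dcp_local_block_size = 8            # = dcp_kv_cache_interleave_size
    dcp_world_size = 4
    dcp_virtual_block_size = dcp_local_block_size * dcp_world_size    # 32

    # cdiv(context_len, dcp_virtual_block_size) * dcp_local_block_size
    padded_local_context_len = ceil(context_len / dcp_virtual_block_size) * dcp_local_block_size
    # rank 0 actually holds 22 tokens and ranks 1-3 hold 16, but all are padded to 24
    assert padded_local_context_len == 24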
@ -908,15 +920,16 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
if self.dcp_world_size > 1:
chunked_context_metadata = chunked_context_metadata_cls(
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
starts=cp_chunk_starts.to(device, non_blocking=True),
seq_tot=cp_chunk_seq_lens.sum(dim=1).tolist(),
starts=local_chunk_starts.to(device, non_blocking=True),
seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
seq_lens=chunk_seq_lens,
workspace=self.chunked_prefill_workspace,
cp_chunk_seq_lens=cp_chunk_seq_lens.tolist(),
origin_context_lens=origin_context_lens,
cp_cu_seq_lens=cp_cu_seq_lens_cpu.to(device, non_blocking=True),
chunk_size=max_context_chunk,
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
local_context_lens_allranks=local_context_lens_allranks.tolist(),
padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.to(
device, non_blocking=True
),
cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
)
else:
@ -998,64 +1011,52 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
def reorg_kvcache(
allgatered_kv_c_normed: torch.Tensor,
allgatered_k_pe: torch.Tensor,
cp_chunk_seq_lens_lst: list[int],
origin_context_lens: list[int],
cp_world_size: int,
padded_local_chunk_seq_lens_lst: list[int],
local_context_lens_allranks: list[list[int]],
sum_seq_len: int,
max_seq_len: int,
chunk_size: int,
chunk_idx: int,
toks: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
reorg kvcache after cp local gather to tp layout for attn kernel.

reorg and unpad kvcache after cp local gather to tp layout for attn kernel.
e.g.
allgatered_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T1_0, T1_1, ...,
                          T0_4, T0_5, pad, pad, T1_2, pad, ...]
-> reorganized_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T0_4, T0_5,
                              T1_0, T1_1, T1_2, ...]
Args:
cp_chunk_seq_lens_lst: chunk context lengths under CP.
origin_context_lens: origin full context lengths under CP.
cp_world_size: CP size.
padded_local_chunk_seq_lens_lst: local chunk context lengths
    under current CP rank.
local_context_lens_allranks: local context lengths on each CP rank.
sum_seq_len: the sum of cp_chunk_seq_lens_lst.
max_seq_len: the max value of cp_chunk_seq_lens_lst.
chunk_size: equals to max_context_chunk from
    chunked_context_metadata building.
chunk_idx: chunk idx of chunked_prefill.
toks: the number of tokens for local gather cache.
"""
kv_c_segments = []
k_pe_segments = []
src_token_idx = 0
max_seq_len_check = 0
for cp_chunk_seq_len, origin_context_len in zip(
cp_chunk_seq_lens_lst, origin_context_lens
for padded_local_chunk_seq_len, local_context_lens in zip(
padded_local_chunk_seq_lens_lst, local_context_lens_allranks
):
chunk_context_len = chunk_size
if cp_chunk_seq_len != 0:
chunk_context_len = min(
chunk_context_len, origin_context_len - chunk_size * chunk_idx
)
cp_target_rank = (chunk_context_len - 1) % cp_world_size
cur_seq_len = 0
for rank in range(cp_world_size):
if rank > cp_target_rank and cp_chunk_seq_len:
real_cp_chunk_seq_len = cp_chunk_seq_len - 1
else:
real_cp_chunk_seq_len = cp_chunk_seq_len
if real_cp_chunk_seq_len:
for rank, local_context_len in enumerate(local_context_lens):
if local_context_len != 0:
kv_c_segment = allgatered_kv_c_normed[
rank * toks + src_token_idx : rank * toks
+ src_token_idx
+ real_cp_chunk_seq_len
+ local_context_len
]
k_pe_segment = allgatered_k_pe[
rank * toks + src_token_idx : rank * toks
+ src_token_idx
+ real_cp_chunk_seq_len
+ local_context_len
]
kv_c_segments.append(kv_c_segment)
k_pe_segments.append(k_pe_segment)
cur_seq_len += real_cp_chunk_seq_len
cur_seq_len += local_context_len
max_seq_len_check = max(max_seq_len_check, cur_seq_len)
src_token_idx += cp_chunk_seq_len
src_token_idx += padded_local_chunk_seq_len
reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
reorganized_k_pe = torch.cat(k_pe_segments, dim=0)
assert reorganized_kv_c_normed.shape[0] == sum_seq_len
@ -1296,6 +1297,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
get_current_vllm_config()
)
)
self.dcp_kv_cache_interleave_size: int = (
get_current_vllm_config().parallel_config.dcp_kv_cache_interleave_size
)

def _flash_attn_varlen_diff_headdims(
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
@ -1697,10 +1701,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
assert attn_metadata.prefill is not None
prefill_metadata = attn_metadata.prefill
assert prefill_metadata.chunked_context is not None
assert prefill_metadata.chunked_context.cp_chunk_seq_lens is not None
assert prefill_metadata.chunked_context.origin_context_lens is not None
assert prefill_metadata.chunked_context.cp_cu_seq_lens is not None
assert prefill_metadata.chunked_context.chunk_size is not None
assert prefill_metadata.chunked_context.padded_local_chunk_seq_lens is not None
assert prefill_metadata.chunked_context.local_context_lens_allranks is not None
assert prefill_metadata.chunked_context.padded_local_cu_seq_lens is not None
assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None

output = None

@ -1713,7 +1716,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
src_cache=kv_c_and_k_pe_cache,
dst=workspace,
block_table=prefill_metadata.block_table,
cu_seq_lens=prefill_metadata.chunked_context.cp_cu_seq_lens[i],
cu_seq_lens=prefill_metadata.chunked_context.padded_local_cu_seq_lens[
i
],
batch_size=attn_metadata.num_prefills,
seq_starts=prefill_metadata.chunked_context.starts[i],
)

@ -1743,15 +1748,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
kv_c_normed, k_pe = reorg_kvcache(
allgatered_kv_c_normed,
allgatered_k_pe,
cp_chunk_seq_lens_lst=prefill_metadata.chunked_context.cp_chunk_seq_lens[
padded_local_chunk_seq_lens_lst=prefill_metadata.chunked_context.padded_local_chunk_seq_lens[
i
],
origin_context_lens=prefill_metadata.chunked_context.origin_context_lens,
cp_world_size=dcp_world_size,
local_context_lens_allranks=prefill_metadata.chunked_context.local_context_lens_allranks,
sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1],
max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i],
chunk_size=prefill_metadata.chunked_context.chunk_size,
chunk_idx=i,
toks=toks,
)
@ -1076,3 +1076,41 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
nums_dict[BLOCK_M]["token_chunk_offset_ptr"] = token_chunk_offset_ptr  # type: ignore

return nums_dict, batch_ptr, token_chunk_offset_ptr


def get_dcp_local_seq_lens(
seq_lens: torch.Tensor,
dcp_world_size: int = 1,
dcp_rank: int | None = None,
dcp_kv_cache_interleave_size: int = 1,
) -> torch.Tensor:
"""While using dcp, the kv_cache size stored on each rank may differ;
use this function to calculate the per-rank split of decode seq_lens.
Only dcp is considered for now; the cp case can be extended from this.
"""
num_requests = seq_lens.size(0)
if dcp_rank is None:
rank_offsets = (
torch.arange(dcp_world_size, dtype=torch.int32)
.unsqueeze(0)
.repeat(num_requests, 1)
)
else:
rank_offsets = torch.Tensor([[dcp_rank]]).to(dtype=torch.int32)
seq_lens_tiled = (
seq_lens.to(torch.int32).unsqueeze(-1).repeat(1, rank_offsets.shape[1])
)
base = (
seq_lens_tiled
// dcp_kv_cache_interleave_size
// dcp_world_size
* dcp_kv_cache_interleave_size
)
remainder = seq_lens_tiled - base * dcp_world_size
remainder = torch.clip(
remainder - rank_offsets * dcp_kv_cache_interleave_size,
0,
dcp_kv_cache_interleave_size,
)
dcp_local_seq_lens = base + remainder
return dcp_local_seq_lens.squeeze(1)
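A short usage sketch of the helper above (values chosen for illustration): a 13-token sequence over 4 DCP ranks with an interleave size of 2 is dealt out in 2-token groups, and the trailing partial group goes to the next rank in line.

    import torch

    seq_lens = torch.tensor([13])
    local = get_dcp_local_seq_lens(
        seq_lens, dcp_world_size=4, dcp_rank=None, dcp_kv_cache_interleave_size=2
    )
    # tokens 0-1 -> rank 0, 2-3 -> rank 1, 4-5 -> rank 2, 6-7 -> rank 3,
    # 8-9 -> rank 0, 10-11 -> rank 1, 12 -> rank 2
    assert local.tolist() == [[4, 4, 3, 2]]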
@ -22,6 +22,7 @@ class BlockTable:
pin_memory: bool,
device: torch.device,
kernel_block_size: int,
dcp_kv_cache_interleave_size: int,
):
"""
Args:

@ -86,6 +87,7 @@ class BlockTable:
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
self.dcp_kv_cache_interleave_size = dcp_kv_cache_interleave_size

def append_row(
self,

@ -144,9 +146,19 @@ class BlockTable:
# Use virtual_block_size for mask calculation, which marks local
# tokens.
virtual_block_offsets = positions % virtual_block_size
mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank
mask = (
virtual_block_offsets
// self.dcp_kv_cache_interleave_size
% self.dcp_world_size
== self.dcp_rank
)
# Calculate local block_offsets
block_offsets = virtual_block_offsets // self.dcp_world_size
block_offsets = (
virtual_block_offsets
// (self.dcp_world_size * self.dcp_kv_cache_interleave_size)
* self.dcp_kv_cache_interleave_size
+ virtual_block_offsets % self.dcp_kv_cache_interleave_size
)
# Calculate slot_mapping
slot_mapping = block_numbers * self.block_size + block_offsets
# Write final slots, use -1 for not-local
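A worked example of the new mask and local block offset computation, assuming block_size=4, dcp_world_size=2 and dcp_kv_cache_interleave_size=2 (so virtual_block_size=8):

    interleave, world, block_size = 2, 2, 4
    virtual_block_size = block_size * world
    for pos in range(8):
        off = pos % virtual_block_size
        rank = off // interleave % world                                      # owning rank
        local = off // (world * interleave) * interleave + off % interleave  # offset in that rank's block
        print(pos, rank, local)
    # pos:   0 1 2 3 4 5 6 7
    # rank:  0 0 1 1 0 0 1 1
    # local: 0 1 0 1 2 3 2 3
    # With interleave_size == 1 this degenerates to the previous token-round-robin layout.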
@ -234,6 +246,7 @@ class MultiGroupBlockTable:
block_sizes: list[int],
kernel_block_sizes: list[int],
num_speculative_tokens: int = 0,
dcp_kv_cache_interleave_size: int = 1,
) -> None:
# Note(hc): each dcp rank only stores
# (max_model_len//dcp_world_size) tokens in kvcache,

@ -263,6 +276,7 @@ class MultiGroupBlockTable:
pin_memory,
device,
kernel_block_size,
dcp_kv_cache_interleave_size,
)
for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
]
@ -84,6 +84,7 @@ class InputBatch:
is_spec_decode: bool = False,
is_pooling_model: bool = False,
num_speculative_tokens: int = 0,
dcp_kv_cache_interleave_size: int = 1,
):
self.is_pooling_model = is_pooling_model
self.is_spec_decode = is_spec_decode

@ -137,6 +138,7 @@ class InputBatch:
block_sizes=block_sizes,
kernel_block_sizes=kernel_block_sizes,
num_speculative_tokens=num_speculative_tokens,
dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
)

# Sampling-related.
@ -35,6 +35,7 @@ from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
from vllm.distributed.parallel_state import (
get_dcp_group,
get_pp_group,
get_tp_group,
graph_capture,

@ -88,6 +89,7 @@ from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
create_fast_prefill_custom_backend,
get_dcp_local_seq_lens,
reorder_batch_to_split_decodes_and_prefills,
split_attn_metadata,
)

@ -275,6 +277,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.is_multimodal_pruning_enabled = False
self.max_model_len = model_config.max_model_len
self.dcp_world_size = self.parallel_config.decode_context_parallel_size
self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group
self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs

@ -396,6 +399,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# uses output token ids so we set this conservatively.
logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
is_pooling_model=self.is_pooling_model,
dcp_kv_cache_interleave_size=self.parallel_config.dcp_kv_cache_interleave_size,
)

self.use_async_scheduling = self.scheduler_config.async_scheduling

@ -1307,6 +1311,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits_indices
)

# update seq_lens of decode reqs under DCP.
if self.dcp_world_size > 1:
self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens(
self.seq_lens.cpu[:num_reqs],
self.dcp_world_size,
self.dcp_rank,
self.parallel_config.dcp_kv_cache_interleave_size,
)
self.dcp_local_seq_lens.copy_to_gpu(num_reqs)

attn_metadata: PerLayerAttnMetadata = {}
if ubatch_slices is not None:
attn_metadata = [dict() for _ in range(len(ubatch_slices))]