[DCP] Support dcp kv_cache interleave size > 1 (#26696)

Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Signed-off-by: Qiu <qiuchunshuo@huawei.com>
Co-authored-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
zhangsicheng5 2025-11-09 03:45:27 +08:00 committed by GitHub
parent 47604137a2
commit 2108a571d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 202 additions and 79 deletions

View File

@ -30,6 +30,7 @@ class ParallelSetup(NamedTuple):
tp_size: int
pp_size: int
dcp_size: int
dcp_kv_cache_interleave_size: int
eager_mode: bool
chunked_prefill: bool
@ -52,6 +53,7 @@ class CPTestSettings:
tp_base: int = 4,
pp_base: int = 1,
dcp_base: int = 1,
dcp_kv_cache_interleave_size: int = 1,
multi_node_only: bool = False,
runner: RunnerOption = "auto",
load_format: str | None = None,
@ -66,6 +68,7 @@ class CPTestSettings:
tp_size=tp_base,
pp_size=pp_multiplier * pp_base,
dcp_size=int(dcp_multiplier * tp_base),
dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
eager_mode=eager_mode_val,
chunked_prefill=chunked_prefill_val,
)
@ -108,6 +111,7 @@ def _compare_cp_with_tp(
tp_size,
pp_size,
dcp_size,
dcp_kv_cache_interleave_size,
eager_mode,
chunked_prefill,
) = parallel_setup
@ -180,6 +184,8 @@ def _compare_cp_with_tp(
str(pp_size),
"--decode-context-parallel-size",
str(dcp_size),
"--dcp-kv-cache-interleave-size",
str(dcp_kv_cache_interleave_size),
"--distributed-executor-backend",
distributed_backend,
]
@ -207,6 +213,7 @@ CP_TEXT_GENERATION_MODELS = {
"deepseek-ai/DeepSeek-V2-Lite-Chat": [
CPTestSettings.detailed(),
CPTestSettings.detailed(tp_base=2),
CPTestSettings.detailed(tp_base=2, dcp_kv_cache_interleave_size=64),
],
"bigcode/gpt_bigcode-santacoder": [
CPTestSettings.detailed(),

View File

@ -951,6 +951,7 @@ def test_hybrid_block_table_initialization():
max_num_reqs = 10
max_num_blocks_per_req = 20
max_num_batched_tokens = 512
dcp_kv_cache_interleave_size = 8
block_table = BlockTable(
block_size=block_size,
@ -960,6 +961,7 @@ def test_hybrid_block_table_initialization():
pin_memory=False,
device=torch.device(DEVICE),
kernel_block_size=kernel_block_sizes[0],
dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
)
# Verify hybrid block configuration

View File

@ -53,6 +53,7 @@ def _correct_attn_cp_out_kernel(
lse = tl.load(lses_ptr + lse_offsets)
lse = tl.where((lse != lse) | (lse == float("inf")), -float("inf"), lse)
lse_max = tl.max(lse, axis=0)
lse_max = tl.where(lse_max == -float("inf"), 0, lse_max)
lse -= lse_max
lse_exp = tl.exp(lse)
lse_acc = tl.sum(lse_exp, axis=0)

View File

@ -227,6 +227,17 @@ class ParallelConfig:
not change by dcp, it simply reuse the GPUs of TP group, and tp_size
needs to be divisible by dcp_size."""
dcp_kv_cache_interleave_size: int = 1
"""Interleave size of kv_cache storage while using dcp or cp > 1,
store interleave_size tokens on (d)cp i,
then store next interleave_size tokens on (d)cp i+1.
Interleave_size=1: token-level align, token i is stored on rank i % (d)cp_size.
Interleave_size=block_size: block-level align, first fill the block on first rank,
token is stored on rank i+1 block j after rank i block j is full.
Block_size should be greater than or equal to dcp_kv_cache_interleave_size.
Block_size should be divisible by dcp_kv_cache_interleave_size.
"""
_api_process_count: int = Field(default=1, gt=0)
"""
The number of API processes initialized.

View File

@ -608,6 +608,23 @@ class VllmConfig:
)
current_platform.check_and_update_config(self)
assert (
self.parallel_config.dcp_kv_cache_interleave_size
<= self.cache_config.block_size
and self.cache_config.block_size
% self.parallel_config.dcp_kv_cache_interleave_size
== 0
), (
f"Block_size({self.cache_config.block_size}) should be "
"greater than or equal to and divisible by dcp_kv_cache_interleave_size "
f"({self.parallel_config.dcp_kv_cache_interleave_size})."
)
assert (
self.parallel_config.dcp_kv_cache_interleave_size == 1
or self.speculative_config is None
), "MTP with dcp_kv_cache_interleave_size > 1 is not supported now."
# Do this after all the updates to compilation_config.mode
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
self.compilation_config.set_splitting_ops_for_v1()

View File

@ -385,6 +385,7 @@ class EngineArgs:
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
data_parallel_size: int = ParallelConfig.data_parallel_size
data_parallel_rank: int | None = None
data_parallel_start_rank: int | None = None
@ -750,6 +751,10 @@ class EngineArgs:
"-dcp",
**parallel_kwargs["decode_context_parallel_size"],
)
parallel_group.add_argument(
"--dcp-kv-cache-interleave-size",
**parallel_kwargs["dcp_kv_cache_interleave_size"],
)
parallel_group.add_argument(
"--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
)
@ -1518,6 +1523,7 @@ class EngineArgs:
worker_cls=self.worker_cls,
worker_extension_cls=self.worker_extension_cls,
decode_context_parallel_size=self.decode_context_parallel_size,
dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size,
_api_process_count=self._api_process_count,
_api_process_rank=self._api_process_rank,
)

View File

@ -43,6 +43,7 @@ from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
get_dcp_local_seq_lens,
get_kv_cache_layout,
)
from vllm.v1.kv_cache_interface import AttentionSpec
@ -238,6 +239,10 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
self.dcp_world_size = 1
self.dcp_rank = 0
self.dcp_kv_cache_interleave_size = (
self.parallel_config.dcp_kv_cache_interleave_size
)
self.use_full_cuda_graph = (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
)
@ -352,8 +357,12 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
- common_attn_metadata.query_start_loc_cpu[:-1]
)
dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu
dcp_context_kv_lens_cpu = dcp_context_kv_lens_cpu // self.dcp_world_size + (
self.dcp_rank <= (dcp_context_kv_lens_cpu - 1) % self.dcp_world_size
dcp_context_kv_lens_cpu = get_dcp_local_seq_lens(
dcp_context_kv_lens_cpu,
self.dcp_world_size,
self.dcp_rank,
self.dcp_kv_cache_interleave_size,
)
dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
max_dcp_context_kv_len = dcp_context_kv_lens.max().item()

View File

@ -225,6 +225,7 @@ from vllm.utils.math_utils import cdiv, round_down
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
get_dcp_local_seq_lens,
get_per_layer_parameters,
infer_global_hyperparameters,
split_decodes_and_prefills,
@ -361,10 +362,9 @@ class MLACommonPrefillMetadata:
workspace: torch.Tensor
# for mla DCP
cp_chunk_seq_lens: list[list[int]] | None = None
origin_context_lens: list[int] | None = None
cp_cu_seq_lens: torch.Tensor | None = None
chunk_size: int | None = None
padded_local_chunk_seq_lens: list[list[int]] | None = None
local_context_lens_allranks: list[list[int]] | None = None
padded_local_cu_seq_lens: torch.Tensor | None = None
cu_seq_lens_lst: list[list[int]] | None = None
block_table: torch.Tensor
@ -568,6 +568,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
self.dcp_local_block_size = parallel_config.dcp_kv_cache_interleave_size
self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size
# Don't try to access the runner on AMD
if self.aot_schedule:
@ -794,15 +796,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
)
)
# Note(hc): update seq_lens of decode reqs under DCP.
if self.dcp_world_size > 1:
assert dcp_local_seq_lens is not None
dcp_local_seq_lens[:num_decodes] = seq_lens[
:num_decodes
] // self.dcp_world_size + (
self.dcp_rank < seq_lens[:num_decodes] % self.dcp_world_size
)
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_tokens
@ -811,11 +804,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
reqs_start = num_decodes # prefill_start
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
# Note(hc): The context lengths in the perspective of dcp rank0.
cp_context_lens_cpu = torch.ceil(
context_lens_cpu.float() / self.dcp_world_size
).int()
origin_context_lens = context_lens_cpu.tolist()
max_context_len_cpu = context_lens_cpu.max().item()
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
prefill_query_start_loc = (
@ -871,32 +859,56 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
)
if self.dcp_world_size > 1:
local_context_lens_allranks = get_dcp_local_seq_lens(
context_lens_cpu,
self.dcp_world_size,
None,
self.dcp_local_block_size,
)
# Note(qcs): The max local context lengths
# padded to `dcp_local_block_size`.
padded_local_context_lens_cpu = (
cdiv(
context_lens_cpu,
self.dcp_virtual_block_size,
)
* self.dcp_local_block_size
)
# Note(hc): The above max_context_chunk already enforces
# block_size alignment, DCP just need the block_size can
# be divisible by dcp_world_size, because DCP use
# cp_gather_cache which not require `cp_chunk_starts`
# aligned to page_size.
assert max_context_chunk % self.dcp_world_size == 0
cp_max_context_chunk = max_context_chunk // self.dcp_world_size
cp_chunk_starts = (
padded_local_max_context_chunk_across_ranks = (
cdiv(
max_context_chunk,
self.dcp_virtual_block_size,
)
* self.dcp_local_block_size
)
local_chunk_starts = (
torch.arange(num_chunks, dtype=torch.int32)
.unsqueeze(1)
.expand(-1, num_prefills)
* cp_max_context_chunk
* padded_local_max_context_chunk_across_ranks
)
cp_chunk_ends = torch.min(
cp_context_lens_cpu.unsqueeze(0),
cp_chunk_starts + cp_max_context_chunk,
local_chunk_ends = torch.min(
padded_local_context_lens_cpu.unsqueeze(0),
local_chunk_starts
+ padded_local_max_context_chunk_across_ranks,
)
cp_chunk_seq_lens = (cp_chunk_ends - cp_chunk_starts).clamp(min=0)
padded_local_chunk_seq_lens = (
local_chunk_ends - local_chunk_starts
).clamp(min=0)
cp_cu_seq_lens_cpu = torch.zeros(
padded_local_cu_chunk_seq_lens_cpu = torch.zeros(
num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True
)
torch.cumsum(
cp_chunk_seq_lens,
padded_local_chunk_seq_lens,
dim=1,
out=cp_cu_seq_lens_cpu[:, 1:],
out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
dtype=torch.int32,
)
@ -908,15 +920,16 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
if self.dcp_world_size > 1:
chunked_context_metadata = chunked_context_metadata_cls(
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
starts=cp_chunk_starts.to(device, non_blocking=True),
seq_tot=cp_chunk_seq_lens.sum(dim=1).tolist(),
starts=local_chunk_starts.to(device, non_blocking=True),
seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
seq_lens=chunk_seq_lens,
workspace=self.chunked_prefill_workspace,
cp_chunk_seq_lens=cp_chunk_seq_lens.tolist(),
origin_context_lens=origin_context_lens,
cp_cu_seq_lens=cp_cu_seq_lens_cpu.to(device, non_blocking=True),
chunk_size=max_context_chunk,
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
local_context_lens_allranks=local_context_lens_allranks.tolist(),
padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.to(
device, non_blocking=True
),
cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
)
else:
@ -998,64 +1011,52 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
def reorg_kvcache(
allgatered_kv_c_normed: torch.Tensor,
allgatered_k_pe: torch.Tensor,
cp_chunk_seq_lens_lst: list[int],
origin_context_lens: list[int],
cp_world_size: int,
padded_local_chunk_seq_lens_lst: list[int],
local_context_lens_allranks: list[list[int]],
sum_seq_len: int,
max_seq_len: int,
chunk_size: int,
chunk_idx: int,
toks: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
reorg kvcache after cp local gather to tp layout for attn kernel.
reorg and unpad kvcache after cp local gather to tp layout for attn kernel.
e.g.
allgatered_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T1_0, T1_1, ...,
T0_4, T0_5, pad, pad, T1_2, pad, ...]
-> reorganized_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T0_4, T0_5,
T1_0, T1_1, T1_2, ...]
Args:
cp_chunk_seq_lens_lst: chunk context lengths under CP.
origin_context_lens: origin full context lengths under CP.
cp_world_size: CP size.
padded_local_chunk_seq_lens_lst: local chunk context lengths
under current CP rank.
local_context_lens_allranks: local context lengths on each CP rank.
sum_seq_len: the sum of cp_chunk_seq_lens_lst.
max_seq_len: the max value of cp_chunk_seq_lens_lst.
chunk_size: equals to max_context_chunk from
chunked_context_metadata building.
chunk_idx: chunk idx of chunked_prefill.
toks: the number of tokens for local gather cache.
"""
kv_c_segments = []
k_pe_segments = []
src_token_idx = 0
max_seq_len_check = 0
for cp_chunk_seq_len, origin_context_len in zip(
cp_chunk_seq_lens_lst, origin_context_lens
for padded_local_chunk_seq_len, local_context_lens in zip(
padded_local_chunk_seq_lens_lst, local_context_lens_allranks
):
chunk_context_len = chunk_size
if cp_chunk_seq_len != 0:
chunk_context_len = min(
chunk_context_len, origin_context_len - chunk_size * chunk_idx
)
cp_target_rank = (chunk_context_len - 1) % cp_world_size
cur_seq_len = 0
for rank in range(cp_world_size):
if rank > cp_target_rank and cp_chunk_seq_len:
real_cp_chunk_seq_len = cp_chunk_seq_len - 1
else:
real_cp_chunk_seq_len = cp_chunk_seq_len
if real_cp_chunk_seq_len:
for rank, local_context_len in enumerate(local_context_lens):
if local_context_len != 0:
kv_c_segment = allgatered_kv_c_normed[
rank * toks + src_token_idx : rank * toks
+ src_token_idx
+ real_cp_chunk_seq_len
+ local_context_len
]
k_pe_segment = allgatered_k_pe[
rank * toks + src_token_idx : rank * toks
+ src_token_idx
+ real_cp_chunk_seq_len
+ local_context_len
]
kv_c_segments.append(kv_c_segment)
k_pe_segments.append(k_pe_segment)
cur_seq_len += real_cp_chunk_seq_len
cur_seq_len += local_context_len
max_seq_len_check = max(max_seq_len_check, cur_seq_len)
src_token_idx += cp_chunk_seq_len
src_token_idx += padded_local_chunk_seq_len
reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
reorganized_k_pe = torch.cat(k_pe_segments, dim=0)
assert reorganized_kv_c_normed.shape[0] == sum_seq_len
@ -1296,6 +1297,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
get_current_vllm_config()
)
)
self.dcp_kv_cache_interleave_size: int = (
get_current_vllm_config().parallel_config.dcp_kv_cache_interleave_size
)
def _flash_attn_varlen_diff_headdims(
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
@ -1697,10 +1701,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
assert attn_metadata.prefill is not None
prefill_metadata = attn_metadata.prefill
assert prefill_metadata.chunked_context is not None
assert prefill_metadata.chunked_context.cp_chunk_seq_lens is not None
assert prefill_metadata.chunked_context.origin_context_lens is not None
assert prefill_metadata.chunked_context.cp_cu_seq_lens is not None
assert prefill_metadata.chunked_context.chunk_size is not None
assert prefill_metadata.chunked_context.padded_local_chunk_seq_lens is not None
assert prefill_metadata.chunked_context.local_context_lens_allranks is not None
assert prefill_metadata.chunked_context.padded_local_cu_seq_lens is not None
assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None
output = None
@ -1713,7 +1716,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
src_cache=kv_c_and_k_pe_cache,
dst=workspace,
block_table=prefill_metadata.block_table,
cu_seq_lens=prefill_metadata.chunked_context.cp_cu_seq_lens[i],
cu_seq_lens=prefill_metadata.chunked_context.padded_local_cu_seq_lens[
i
],
batch_size=attn_metadata.num_prefills,
seq_starts=prefill_metadata.chunked_context.starts[i],
)
@ -1743,15 +1748,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
kv_c_normed, k_pe = reorg_kvcache(
allgatered_kv_c_normed,
allgatered_k_pe,
cp_chunk_seq_lens_lst=prefill_metadata.chunked_context.cp_chunk_seq_lens[
padded_local_chunk_seq_lens_lst=prefill_metadata.chunked_context.padded_local_chunk_seq_lens[
i
],
origin_context_lens=prefill_metadata.chunked_context.origin_context_lens,
cp_world_size=dcp_world_size,
local_context_lens_allranks=prefill_metadata.chunked_context.local_context_lens_allranks,
sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1],
max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i],
chunk_size=prefill_metadata.chunked_context.chunk_size,
chunk_idx=i,
toks=toks,
)

View File

@ -1076,3 +1076,41 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
nums_dict[BLOCK_M]["token_chunk_offset_ptr"] = token_chunk_offset_ptr # type: ignore
return nums_dict, batch_ptr, token_chunk_offset_ptr
def get_dcp_local_seq_lens(
seq_lens: torch.Tensor,
dcp_world_size: int = 1,
dcp_rank: int | None = None,
dcp_kv_cache_interleave_size: int = 1,
) -> torch.Tensor:
"""While using dcp, kv_cache size stored on each rank may be different,
use this function to calculate split decode seq_lens of each dcp rank.
Only consider dcp now, we can extend the case of cp based on this.
"""
num_requests = seq_lens.size(0)
if dcp_rank is None:
rank_offsets = (
torch.arange(dcp_world_size, dtype=torch.int32)
.unsqueeze(0)
.repeat(num_requests, 1)
)
else:
rank_offsets = torch.Tensor([[dcp_rank]]).to(dtype=torch.int32)
seq_lens_tiled = (
seq_lens.to(torch.int32).unsqueeze(-1).repeat(1, rank_offsets.shape[1])
)
base = (
seq_lens_tiled
// dcp_kv_cache_interleave_size
// dcp_world_size
* dcp_kv_cache_interleave_size
)
remainder = seq_lens_tiled - base * dcp_world_size
remainder = torch.clip(
remainder - rank_offsets * dcp_kv_cache_interleave_size,
0,
dcp_kv_cache_interleave_size,
)
dcp_local_seq_lens = base + remainder
return dcp_local_seq_lens.squeeze(1)

View File

@ -22,6 +22,7 @@ class BlockTable:
pin_memory: bool,
device: torch.device,
kernel_block_size: int,
dcp_kv_cache_interleave_size: int,
):
"""
Args:
@ -86,6 +87,7 @@ class BlockTable:
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
self.dcp_kv_cache_interleave_size = dcp_kv_cache_interleave_size
def append_row(
self,
@ -144,9 +146,19 @@ class BlockTable:
# Use virtual_block_size for mask calculation, which marks local
# tokens.
virtual_block_offsets = positions % virtual_block_size
mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank
mask = (
virtual_block_offsets
// self.dcp_kv_cache_interleave_size
% self.dcp_world_size
== self.dcp_rank
)
# Calculate local block_offsets
block_offsets = virtual_block_offsets // self.dcp_world_size
block_offsets = (
virtual_block_offsets
// (self.dcp_world_size * self.dcp_kv_cache_interleave_size)
* self.dcp_kv_cache_interleave_size
+ virtual_block_offsets % self.dcp_kv_cache_interleave_size
)
# Calculate slot_mapping
slot_mapping = block_numbers * self.block_size + block_offsets
# Write final slots, use -1 for not-local
@ -234,6 +246,7 @@ class MultiGroupBlockTable:
block_sizes: list[int],
kernel_block_sizes: list[int],
num_speculative_tokens: int = 0,
dcp_kv_cache_interleave_size: int = 1,
) -> None:
# Note(hc): each dcp rank only store
# (max_model_len//dcp_world_size) tokens in kvcache,
@ -263,6 +276,7 @@ class MultiGroupBlockTable:
pin_memory,
device,
kernel_block_size,
dcp_kv_cache_interleave_size,
)
for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
]

View File

@ -84,6 +84,7 @@ class InputBatch:
is_spec_decode: bool = False,
is_pooling_model: bool = False,
num_speculative_tokens: int = 0,
dcp_kv_cache_interleave_size: int = 1,
):
self.is_pooling_model = is_pooling_model
self.is_spec_decode = is_spec_decode
@ -137,6 +138,7 @@ class InputBatch:
block_sizes=block_sizes,
kernel_block_sizes=kernel_block_sizes,
num_speculative_tokens=num_speculative_tokens,
dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
)
# Sampling-related.

View File

@ -35,6 +35,7 @@ from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
from vllm.distributed.parallel_state import (
get_dcp_group,
get_pp_group,
get_tp_group,
graph_capture,
@ -88,6 +89,7 @@ from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
create_fast_prefill_custom_backend,
get_dcp_local_seq_lens,
reorder_batch_to_split_decodes_and_prefills,
split_attn_metadata,
)
@ -275,6 +277,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.is_multimodal_pruning_enabled = False
self.max_model_len = model_config.max_model_len
self.dcp_world_size = self.parallel_config.decode_context_parallel_size
self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group
self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs
@ -396,6 +399,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# uses output token ids so we set this conservatively.
logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
is_pooling_model=self.is_pooling_model,
dcp_kv_cache_interleave_size=self.parallel_config.dcp_kv_cache_interleave_size,
)
self.use_async_scheduling = self.scheduler_config.async_scheduling
@ -1307,6 +1311,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits_indices
)
# update seq_lens of decode reqs under DCP.
if self.dcp_world_size > 1:
self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens(
self.seq_lens.cpu[:num_reqs],
self.dcp_world_size,
self.dcp_rank,
self.parallel_config.dcp_kv_cache_interleave_size,
)
self.dcp_local_seq_lens.copy_to_gpu(num_reqs)
attn_metadata: PerLayerAttnMetadata = {}
if ubatch_slices is not None:
attn_metadata = [dict() for _ in range(len(ubatch_slices))]