[Feature] Prefill Context Parallel (PCP) basic support (#28718)

Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: LookAround <lixushi@huawei.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
Co-authored-by: FENP <yuanyongjie.yyj@antgroup.com>
Co-authored-by: LookAround <lixushi@huawei.com>
Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com>
Co-authored-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
Co-authored-by: Jingchun Gao <63247409+gjc0824@users.noreply.github.com>
Qiu 2025-11-20 04:52:44 +08:00 committed by GitHub
parent 02f5903b84
commit 2fd893b4ce
27 changed files with 399 additions and 114 deletions


@@ -31,7 +31,7 @@ class ParallelSetup(NamedTuple):
     tp_size: int
     pp_size: int
     dcp_size: int
-    dcp_kv_cache_interleave_size: int
+    cp_kv_cache_interleave_size: int
     eager_mode: bool
     chunked_prefill: bool
@@ -55,7 +55,7 @@ class CPTestSettings:
         tp_base: int = 4,
         pp_base: int = 1,
         dcp_base: int = 1,
-        dcp_kv_cache_interleave_size: int = 1,
+        cp_kv_cache_interleave_size: int = 1,
         multi_node_only: bool = False,
         runner: RunnerOption = "auto",
         load_format: str | None = None,
@@ -71,7 +71,7 @@ class CPTestSettings:
                     tp_size=tp_base,
                     pp_size=pp_multiplier * pp_base,
                     dcp_size=int(dcp_multiplier * tp_base),
-                    dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
+                    cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
                     eager_mode=eager_mode_val,
                     chunked_prefill=chunked_prefill_val,
                 )
@@ -116,7 +116,7 @@ def _compare_cp_with_tp(
         tp_size,
         pp_size,
         dcp_size,
-        dcp_kv_cache_interleave_size,
+        cp_kv_cache_interleave_size,
         eager_mode,
         chunked_prefill,
     ) = parallel_setup
@@ -197,7 +197,7 @@ def _compare_cp_with_tp(
         "--decode-context-parallel-size",
         str(dcp_size),
         "--dcp-kv-cache-interleave-size",
-        str(dcp_kv_cache_interleave_size),
+        str(cp_kv_cache_interleave_size),
         "--distributed-executor-backend",
         distributed_backend,
     ]
@@ -227,7 +227,7 @@ CP_TEXT_GENERATION_MODELS = {
     "deepseek-ai/DeepSeek-V2-Lite-Chat": [
         CPTestSettings.detailed(),
         CPTestSettings.detailed(tp_base=2),
-        CPTestSettings.detailed(tp_base=2, dcp_kv_cache_interleave_size=64),
+        CPTestSettings.detailed(tp_base=2, cp_kv_cache_interleave_size=64),
     ],
     "bigcode/gpt_bigcode-santacoder": [
         CPTestSettings.detailed(),


@@ -15,7 +15,11 @@ from tests.kernels.quantization.nvfp4_utils import (
 )
 from tests.kernels.utils import torch_experts
 from vllm.config import VllmConfig
-from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import (
+    get_dp_group,
+    get_pcp_group,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.forward_context import set_forward_context
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
@@ -561,6 +565,7 @@ def make_modular_kernel(
     # make moe config
     moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
         tp_size_=get_tensor_model_parallel_world_size(),
+        pcp_size_=get_pcp_group().world_size,
        dp_size_=get_dp_group().world_size,
        vllm_parallel_config=vllm_config.parallel_config,
    )


@@ -956,7 +956,7 @@ def test_hybrid_block_table_initialization():
     max_num_reqs = 10
     max_num_blocks_per_req = 20
     max_num_batched_tokens = 512
-    dcp_kv_cache_interleave_size = 8
+    cp_kv_cache_interleave_size = 8

     block_table = BlockTable(
         block_size=block_size,
@@ -966,7 +966,7 @@ def test_hybrid_block_table_initialization():
         pin_memory=False,
         device=torch.device(DEVICE),
         kernel_block_size=kernel_block_sizes[0],
-        dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
+        cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
     )

     # Verify hybrid block configuration


@@ -266,6 +266,12 @@ class AttentionImpl(ABC, Generic[T]):
     dcp_world_size: int
     dcp_rank: int

+    pcp_world_size: int
+    pcp_rank: int
+
+    total_cp_world_size: int
+    total_cp_rank: int
+
     def __new__(cls, *args, **kwargs):
         # use __new__ so that all subclasses will call this
         self = super().__new__(cls)
@@ -278,6 +284,17 @@ class AttentionImpl(ABC, Generic[T]):
             # DCP might not be initialized in testing
             self.dcp_world_size = 1
             self.dcp_rank = 0
+        try:
+            from vllm.distributed.parallel_state import get_pcp_group
+
+            self.pcp_world_size = get_pcp_group().world_size
+            self.pcp_rank = get_pcp_group().rank_in_group
+        except AssertionError:
+            self.pcp_world_size = 1
+            self.pcp_rank = 0
+        self.total_cp_world_size = self.pcp_world_size * self.dcp_world_size
+        self.total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
         self.need_to_return_lse_for_decode = (
             self.dcp_world_size > 1 and self.can_return_lse_for_decode
         )
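For reference, a small standalone sketch (plain Python, sizes invented for illustration) of how the combined context-parallel coordinates computed above index the PCP x DCP grid, with DCP as the fast-moving dimension:

pcp_world_size, dcp_world_size = 2, 4
total_cp_world_size = pcp_world_size * dcp_world_size  # 8 ranks in total

for pcp_rank in range(pcp_world_size):
    for dcp_rank in range(dcp_world_size):
        # same formula as AttentionImpl.__new__ above
        total_cp_rank = pcp_rank * dcp_world_size + dcp_rank
        print((pcp_rank, dcp_rank), "->", total_cp_rank)
# (0, 0)->0 ... (0, 3)->3, (1, 0)->4 ... (1, 3)->7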


@@ -169,12 +169,11 @@ def correct_attn_out(
     return out, lse


-def cp_lse_ag_out_rs(
+def _cp_lse_common(
     cp_attn_out: torch.Tensor,
     cp_attn_lse: torch.Tensor,
     cp_group: GroupCoordinator,
-    ctx: CPTritonContext = None,
-    return_lse=False,
+    ctx: CPTritonContext | None = None,
 ):
     """
     cp_attn_out: [ B, H, D ]
@@ -195,6 +194,22 @@ def cp_lse_ag_out_rs(
     cp_attn_lse = cp_attn_lse.contiguous()
     lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
     out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
+    assert out.is_contiguous()
+    return out, lse
+
+
+def cp_lse_ag_out_rs(
+    cp_attn_out: torch.Tensor,
+    cp_attn_lse: torch.Tensor,
+    cp_group: GroupCoordinator,
+    ctx: CPTritonContext | None = None,
+    return_lse: bool = False,
+):
+    """
+    cp_attn_out: [ B, H, D ]
+    cp_attn_lse: [ B, H ]
+    """
+    out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
     out = cp_group.reduce_scatter(out, dim=1)

     if return_lse:
@@ -205,6 +220,25 @@ def cp_lse_ag_out_rs(
     return out


+def cp_lse_ag_out_ar(
+    cp_attn_out: torch.Tensor,
+    cp_attn_lse: torch.Tensor,
+    cp_group: GroupCoordinator,
+    ctx: CPTritonContext | None = None,
+    return_lse: bool = False,
+):
+    """
+    cp_attn_out: [ B, H, D ]
+    cp_attn_lse: [ B, H ]
+    """
+    out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
+    out = cp_group.all_reduce(out)
+
+    if return_lse:
+        return out, lse
+    return out
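Both helpers feed the all-gathered log-sum-exp values into correct_attn_out before the final collective (reduce-scatter for cp_lse_ag_out_rs, all-reduce for the new cp_lse_ag_out_ar). As a rough single-process illustration of the underlying math (not the Triton kernel used here), partial attention outputs computed over disjoint KV shards can be merged by softmax-weighting them with their LSEs:

import torch

def merge_cp_partial_attn(outs: torch.Tensor, lses: torch.Tensor) -> torch.Tensor:
    # outs: [cp_size, B, H, D] per-rank partial outputs, lses: [cp_size, B, H]
    weights = torch.softmax(lses, dim=0)  # exp(lse_r - logsumexp_r(lse))
    return (outs * weights.unsqueeze(-1)).sum(dim=0)  # [B, H, D]

outs = torch.randn(2, 4, 8, 16)  # 2 CP ranks, batch 4, 8 heads, head_dim 16
lses = torch.randn(2, 4, 8)
print(merge_cp_partial_attn(outs, lses).shape)  # torch.Size([4, 8, 16])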
 @triton.jit
 def _pack_seq_kernel(
     x_ptr,  # [N, D]


@@ -71,6 +71,8 @@ class ParallelConfig:
     """Number of pipeline parallel groups."""
     tensor_parallel_size: int = 1
     """Number of tensor parallel groups."""
+    prefill_context_parallel_size: int = 1
+    """Number of prefill context parallel groups."""
     data_parallel_size: int = 1
     """Number of data parallel groups. MoE layers will be sharded according to
     the product of the tensor parallel size and data parallel size."""
@@ -239,14 +241,25 @@ class ParallelConfig:
     needs to be divisible by dcp_size."""

     dcp_kv_cache_interleave_size: int = 1
-    """Interleave size of kv_cache storage while using dcp or cp > 1,
-    store interleave_size tokens on (d)cp i,
-    then store next interleave_size tokens on (d)cp i+1.
-    Interleave_size=1: token-level align, token i is stored on rank i % (d)cp_size.
-    Interleave_size=block_size: block-level align, first fill the block on first rank,
-    token is stored on rank i+1 block j after rank i block j is full.
-    Block_size should be greater than or equal to dcp_kv_cache_interleave_size.
-    Block_size should be divisible by dcp_kv_cache_interleave_size.
+    """
+    Interleave size of kv_cache storage while using DCP.
+    dcp_kv_cache_interleave_size has been replaced by cp_kv_cache_interleave_size
+    and will be deprecated when PCP is fully supported.
+    """
+
+    cp_kv_cache_interleave_size: int = 1
+    """Interleave size of kv_cache storage while using DCP or PCP.
+    For `total_cp_rank = pcp_rank * dcp_world_size + dcp_rank`
+    and `total_cp_world_size = pcp_world_size * dcp_world_size`,
+    store interleave_size tokens on total_cp_rank i,
+    then store the next interleave_size tokens on total_cp_rank i+1.
+    Interleave_size=1: token-level alignment, where token `i` is stored on
+    total_cp_rank `i % total_cp_world_size`.
+    Interleave_size=block_size: block-level alignment, where tokens are
+    first populated to the preceding ranks. Tokens are then stored
+    in (rank i+1, block j) only after (rank i, block j) is fully occupied.
+    Block_size should be greater than or equal to cp_kv_cache_interleave_size.
+    Block_size should be divisible by cp_kv_cache_interleave_size.
     """
     _api_process_count: int = Field(default=1, gt=0)
@@ -311,6 +324,11 @@ class ParallelConfig:
                 "num_redundant_experts."
             )

+        if self.prefill_context_parallel_size > 1:
+            raise ValueError(
+                "Prefill context parallelism is not fully supported. "
+                "Please set prefill_context_parallel_size to 1."
+            )
         return self

     @property
@@ -529,7 +547,11 @@ class ParallelConfig:
             )

         # Continue with the rest of the initialization
-        self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size
+        self.world_size = (
+            self.pipeline_parallel_size
+            * self.tensor_parallel_size
+            * self.prefill_context_parallel_size
+        )

         if self.distributed_executor_backend == "external_launcher":
             logger.info("Using external launcher for distributed inference.")


@@ -481,6 +481,14 @@ class VllmConfig:
                     "Overriding cudagraph_mode to PIECEWISE."
                 )
                 self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
+            # prefill context parallel does not support full cudagraphs
+            elif self.parallel_config.prefill_context_parallel_size > 1:
+                logger.warning_once(
+                    "Prefill context parallel (PCP) is enabled, which is "
+                    "incompatible with full CUDA graphs. "
+                    "Overriding cudagraph_mode to PIECEWISE."
+                )
+                self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
             elif self.model_config is not None:
                 if self.model_config.pooler_config is not None:
                     logger.warning_once(
@@ -610,22 +618,34 @@ class VllmConfig:
         # If DCP, ensure the block size is right.
         if self.parallel_config.decode_context_parallel_size > 1:
+            if self.parallel_config.dcp_kv_cache_interleave_size > 1 and (
+                self.parallel_config.cp_kv_cache_interleave_size
+                != self.parallel_config.dcp_kv_cache_interleave_size
+            ):
+                self.parallel_config.cp_kv_cache_interleave_size = (
+                    self.parallel_config.dcp_kv_cache_interleave_size
+                )
+                logger.warning_once(
+                    "cp_kv_cache_interleave_size is overridden by dcp_kv_cache"
+                    "_interleave_size; dcp-kv-cache-interleave-size will be "
+                    "deprecated when PCP is fully supported."
+                )
             assert (
-                self.parallel_config.dcp_kv_cache_interleave_size
+                self.parallel_config.cp_kv_cache_interleave_size
                 <= self.cache_config.block_size
                 and self.cache_config.block_size
-                % self.parallel_config.dcp_kv_cache_interleave_size
+                % self.parallel_config.cp_kv_cache_interleave_size
                 == 0
             ), (
                 f"Block_size({self.cache_config.block_size}) should be greater "
-                "than or equal to and divisible by dcp_kv_cache_interleave_size "
-                f"({self.parallel_config.dcp_kv_cache_interleave_size})."
+                "than or equal to and divisible by cp_kv_cache_interleave_size "
+                f"({self.parallel_config.cp_kv_cache_interleave_size})."
             )
             assert (
-                self.parallel_config.dcp_kv_cache_interleave_size == 1
+                self.parallel_config.cp_kv_cache_interleave_size == 1
                 or self.speculative_config is None
-            ), "MTP with dcp_kv_cache_interleave_size > 1 is not supported now."
+            ), "MTP with cp_kv_cache_interleave_size > 1 is not supported now."

         # Do this after all the updates to compilation_config.mode
         if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:


@@ -1098,6 +1098,12 @@ get_context_model_parallel_group = get_dcp_group
 _PP: GroupCoordinator | None = None


+def get_pp_group() -> GroupCoordinator:
+    assert _PP is not None, "pipeline model parallel group is not initialized"
+    return _PP
+
+
 _DP: GroupCoordinator | None = None
@@ -1114,9 +1120,12 @@ def get_ep_group() -> GroupCoordinator:
     return _EP


-def get_pp_group() -> GroupCoordinator:
-    assert _PP is not None, "pipeline model parallel group is not initialized"
-    return _PP
+_PCP: GroupCoordinator | None = None
+
+
+def get_pcp_group() -> GroupCoordinator:
+    assert _PCP is not None, "prefill context parallel group is not initialized"
+    return _PCP


 @deprecated(
@@ -1276,6 +1285,7 @@ def init_distributed_environment(
 def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
     pipeline_model_parallel_size: int = 1,
+    prefill_context_model_parallel_size: int = 1,
     decode_context_model_parallel_size: int | None = 1,
     backend: str | None = None,
 ) -> None:
@@ -1325,7 +1335,11 @@ def initialize_model_parallel(
     # to get group_ranks for each dimension, transpose that dimension to the
     # last dimension, then reshape to 2D, then unbind the last dimension
     all_ranks = torch.arange(world_size).reshape(
-        -1, data_parallel_size, pipeline_model_parallel_size, tensor_model_parallel_size
+        -1,
+        data_parallel_size,
+        pipeline_model_parallel_size,
+        prefill_context_model_parallel_size,
+        tensor_model_parallel_size,
     )  # noqa

     # Build the tensor model-parallel groups.
@@ -1360,11 +1374,23 @@ def initialize_model_parallel(
         group_name="dcp",
     )

+    global _PCP
+    assert _PCP is None, "prefill context parallel group is already initialized"
+    group_ranks = (
+        all_ranks.transpose(3, 4)
+        .reshape(-1, prefill_context_model_parallel_size)
+        .unbind(0)
+    )
+    group_ranks = [x.tolist() for x in group_ranks]
+    _PCP = init_model_parallel_group(
+        group_ranks, get_world_group().local_rank, backend, group_name="pcp"
+    )
+
     # Build the pipeline model-parallel groups.
     global _PP
     assert _PP is None, "pipeline model parallel group is already initialized"
     group_ranks = (
-        all_ranks.transpose(2, 3).reshape(-1, pipeline_model_parallel_size).unbind(0)
+        all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0)
     )
     group_ranks = [x.tolist() for x in group_ranks]
     _PP = init_model_parallel_group(
@@ -1373,7 +1399,7 @@ def initialize_model_parallel(
     global _DP
     assert _DP is None, "data parallel group is already initialized"
-    group_ranks = all_ranks.transpose(1, 3).reshape(-1, data_parallel_size).unbind(0)
+    group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0)
     group_ranks = [x.tolist() for x in group_ranks]
     _DP = init_model_parallel_group(
         group_ranks, get_world_group().local_rank, backend, group_name="dp"
@@ -1383,7 +1409,12 @@ def initialize_model_parallel(
     assert _EP is None, "expert parallel group is already initialized"
     group_ranks = (
         all_ranks.transpose(1, 2)
-        .reshape(-1, data_parallel_size * tensor_model_parallel_size)
+        .reshape(
+            -1,
+            data_parallel_size
+            * prefill_context_model_parallel_size
+            * tensor_model_parallel_size,
+        )
         .unbind(0)
     )
     group_ranks = [x.tolist() for x in group_ranks]
@@ -1393,11 +1424,13 @@ def initialize_model_parallel(
     logger.info_once(
         "rank %s in world size %s is assigned as "
-        "DP rank %s, PP rank %s, TP rank %s, EP rank %s",
+        "DP rank %s, PP rank %s, PCP rank %s, "
+        "TP rank %s, EP rank %s",
         rank,
         world_size,
         _DP.rank_in_group,
         _PP.rank_in_group,
+        _PCP.rank_in_group,
         _TP.rank_in_group,
         _EP.rank_in_group,
     )
@@ -1406,6 +1439,7 @@ def initialize_model_parallel(
 def ensure_model_parallel_initialized(
     tensor_model_parallel_size: int,
     pipeline_model_parallel_size: int,
+    prefill_context_model_parallel_size: int = 1,
     decode_context_model_parallel_size: int | None = 1,
     backend: str | None = None,
 ) -> None:
@@ -1418,6 +1452,7 @@ def ensure_model_parallel_initialized(
     initialize_model_parallel(
         tensor_model_parallel_size,
         pipeline_model_parallel_size,
+        prefill_context_model_parallel_size,
         decode_context_model_parallel_size,
         backend,
     )
@@ -1434,6 +1469,12 @@ def ensure_model_parallel_initialized(
         f"got: {pp_world_size=} vs. "
         f"wanted: {pipeline_model_parallel_size=}"
     )
+    pcp_world_size = get_pcp_group().world_size
+    assert pcp_world_size == prefill_context_model_parallel_size, (
+        "prefill context parallel group already initialized, but of unexpected size: "
+        f"{pcp_world_size=} vs. "
+        f"{prefill_context_model_parallel_size=}"
+    )


 def prepare_communication_buffer_for_model(model: torch.nn.Module):
@@ -1445,6 +1486,8 @@ def prepare_communication_buffer_for_model(model: torch.nn.Module):
     """
     if _TP is not None:
         _TP.prepare_communication_buffer_for_model(model)
+    if _PCP is not None:
+        _PCP.prepare_communication_buffer_for_model(model)
     if _PP is not None:
         _PP.prepare_communication_buffer_for_model(model)
     if _DP is not None:
@@ -1520,16 +1563,21 @@ def destroy_model_parallel():
         _TP.destroy()
     _TP = None

-    global _PP
-    if _PP:
-        _PP.destroy()
-    _PP = None
-
     global _DCP
     if _DCP:
         _DCP.destroy()
     _DCP = None

+    global _PCP
+    if _PCP:
+        _PCP.destroy()
+    _PCP = None
+
+    global _PP
+    if _PP:
+        _PP.destroy()
+    _PP = None
+
     global _DP
     if _DP:
         _DP.destroy()
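To see what the extra PCP axis does to the group construction above, here is a toy single-process reproduction of the reshape/transpose trick for an assumed 8-rank layout (DP=1, PP=2, PCP=2, TP=2); the sizes are invented for illustration:

import torch

dp, pp, pcp, tp = 1, 2, 2, 2
all_ranks = torch.arange(dp * pp * pcp * tp).reshape(-1, dp, pp, pcp, tp)

# PCP groups: move the PCP axis (3) last, flatten, unbind rows.
pcp_groups = [x.tolist() for x in all_ranks.transpose(3, 4).reshape(-1, pcp).unbind(0)]
print(pcp_groups)  # [[0, 2], [1, 3], [4, 6], [5, 7]]

# PP groups now swap axis 2 with the new last axis (4).
pp_groups = [x.tolist() for x in all_ranks.transpose(2, 4).reshape(-1, pp).unbind(0)]
print(pp_groups)  # [[0, 4], [2, 6], [1, 5], [3, 7]]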


@@ -389,8 +389,10 @@ class EngineArgs:
     nnodes: int = ParallelConfig.nnodes
     node_rank: int = ParallelConfig.node_rank
     tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
+    prefill_context_parallel_size: int = ParallelConfig.prefill_context_parallel_size
     decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
     dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
+    cp_kv_cache_interleave_size: int = ParallelConfig.cp_kv_cache_interleave_size
     data_parallel_size: int = ParallelConfig.data_parallel_size
     data_parallel_rank: int | None = None
     data_parallel_start_rank: int | None = None
@@ -770,6 +772,15 @@ class EngineArgs:
             "--dcp-kv-cache-interleave-size",
             **parallel_kwargs["dcp_kv_cache_interleave_size"],
         )
+        parallel_group.add_argument(
+            "--cp-kv-cache-interleave-size",
+            **parallel_kwargs["cp_kv_cache_interleave_size"],
+        )
+        parallel_group.add_argument(
+            "--prefill-context-parallel-size",
+            "-pcp",
+            **parallel_kwargs["prefill_context_parallel_size"],
+        )
         parallel_group.add_argument(
             "--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
         )
@@ -1600,6 +1611,7 @@ class EngineArgs:
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
+            prefill_context_parallel_size=self.prefill_context_parallel_size,
             data_parallel_size=self.data_parallel_size,
             data_parallel_rank=self.data_parallel_rank or 0,
             data_parallel_external_lb=data_parallel_external_lb,
@@ -1631,6 +1643,7 @@ class EngineArgs:
             worker_extension_cls=self.worker_extension_cls,
             decode_context_parallel_size=self.decode_context_parallel_size,
             dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size,
+            cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size,
             _api_process_count=self._api_process_count,
             _api_process_rank=self._api_process_rank,
         )
@@ -1952,6 +1965,15 @@ class EngineArgs:
             default_prefix_caching,
         ) = self.get_chunked_prefill_prefix_caching_defaults(model_config)

+        if self.prefill_context_parallel_size > 1:
+            default_chunked_prefill = False
+            default_prefix_caching = False
+            logger.warning(
+                "--prefill-context-parallel-size > 1 is not compatible with "
+                "chunked prefill and prefix caching now. Chunked prefill "
+                "and prefix caching have been disabled by default."
+            )
+
         if self.enable_chunked_prefill is None:
             self.enable_chunked_prefill = default_chunked_prefill
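Assuming PCP is eventually allowed end to end (ParallelConfig currently rejects prefill_context_parallel_size > 1, per the check earlier in this commit), the new engine arguments would be passed like the existing parallel sizes; the model name below is only an example:

from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(
    model="deepseek-ai/DeepSeek-V2-Lite-Chat",
    tensor_parallel_size=2,
    prefill_context_parallel_size=2,  # mirrors --prefill-context-parallel-size / -pcp
    cp_kv_cache_interleave_size=64,   # mirrors --cp-kv-cache-interleave-size
)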


@@ -8,7 +8,11 @@ import torch
 import vllm.envs as envs
 from vllm.config import ParallelConfig
-from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank
+from vllm.distributed import (
+    get_dp_group,
+    get_pcp_group,
+    get_tensor_model_parallel_rank,
+)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
     OCP_MX_DTYPES,
@@ -684,9 +688,11 @@ FUSED_MOE_UNQUANTIZED_CONFIG: FusedMoEQuantConfig = FusedMoEQuantConfig.make()
 @dataclass
 class FusedMoEParallelConfig:
     tp_size: int
+    pcp_size: int
     dp_size: int
     ep_size: int

     tp_rank: int
+    pcp_rank: int
     dp_rank: int
     ep_rank: int
@@ -713,19 +719,22 @@ class FusedMoEParallelConfig:
         return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"

     @staticmethod
-    def flatten_tp_across_dp(
-        tp_size: int, dp_size: int, dp_rank: int
+    def flatten_tp_across_dp_and_pcp(
+        tp_size: int, dp_size: int, dp_rank: int, pcp_size: int, pcp_rank: int
     ) -> tuple[int, int]:
         tp_rank = 0 if tp_size == 1 else get_tensor_model_parallel_rank()
-        # There are actually dp_size * tp_size devices. Update tp_size
-        # and tp_rank so we shard across all devices.
-        flatten_tp_size = dp_size * tp_size
-        flatten_tp_rank = dp_rank * tp_size + tp_rank
+        # There are actually dp_size * pcp_size * tp_size devices.
+        # Update tp_size and tp_rank so we shard across all devices.
+        flatten_tp_size = dp_size * pcp_size * tp_size
+        flatten_tp_rank = dp_rank * pcp_size * tp_size + pcp_rank * tp_size + tp_rank
         return flatten_tp_size, flatten_tp_rank

     @staticmethod
     def make(
-        tp_size_: int, dp_size_: int, vllm_parallel_config: ParallelConfig
+        tp_size_: int,
+        pcp_size_: int,
+        dp_size_: int,
+        vllm_parallel_config: ParallelConfig,
     ) -> "FusedMoEParallelConfig":
         """
         Determine MoE parallel configuration. Based on the input `tp_size_`,
@@ -734,19 +743,22 @@ class FusedMoEParallelConfig:
         Args:
             tp_size_ (int): `tp_size` passed into the FusedMoE constructor.
+            pcp_size_ (int): `pcp_size` passed into the FusedMoE constructor.
             dp_size_ (int): `dp_size` passed into the FusedMoE constructor.
             vllm_parallel_config (ParallelConfig): vLLM's parallel config
                 object which contains the `enable_expert_parallel` flag.

         Examples:
             When there is no parallelism requested,
-            i.e. `tp_size_` = `dp_size_` = 1, we simply return the sizes
+            i.e. `tp_size_` = `pcp_size_` = `dp_size_` = 1, we simply return the sizes
             unaltered and the ranks set to 0.

-            Expert Parallelism is considered only when either `dp_size_` or
+            Expert Parallelism is considered only when either `dp_size_`, `pcp_size_` or
             `tp_size_` is non trivial.

-            When TP = 2, DP = 1 and EP = False, the configuration on different
+            Note that PCP serves the same function as DP here.
+
+            When TP = 2, DP(PCP) = 1 and EP = False, the configuration on different
             devices:

                 - device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} //
@@ -754,7 +766,7 @@ class FusedMoEParallelConfig:
                 - device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0}
                 - Comment : Tensors are sharded across 2 devices.

-            When TP = 1, DP = 2 and EP = False, the configuration on different
+            When TP = 1, DP(PCP) = 2 and EP = False, the configuration on different
             devices:

                 - device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0}
@@ -762,7 +774,7 @@ class FusedMoEParallelConfig:
                 - Comment: There are 2 engine instances and the tensors are sharded
                   across 2 decvices.

-            When TP = 2, DP = 2 and EP = False, the configuration on different
+            When TP = 2, DP(PCP) = 2 and EP = False, the configuration on different
             devices:

                 - device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0}
@@ -772,14 +784,14 @@ class FusedMoEParallelConfig:
                 - Comment: There are 2 engine instances and the tensors are sharded
                   across 4 devices.

-            When, TP = 2, DP = 1 and EP = True, the configuration on different
+            When, TP = 2, DP(PCP) = 1 and EP = True, the configuration on different
             devices:

                 - device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0}
                 - device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1}
                 - Comment: The experts are split between the 2 devices.

-            When, TP = 1, DP = 2 and EP = True, the configuration on different
+            When, TP = 1, DP(PCP) = 2 and EP = True, the configuration on different
             devices:

                 - device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0}
@@ -787,7 +799,7 @@ class FusedMoEParallelConfig:
                 - Comment: There are 2 engine instances and the experts are split
                   between the 2 devices.

-            When TP = 2, DP = 2 and EP = True, the configuration on different
+            When TP = 2, DP(PCP) = 2 and EP = True, the configuration on different
             devices:

                 - device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0}
@@ -798,18 +810,25 @@ class FusedMoEParallelConfig:
                   between the 4 devices.
         """
-        use_ep = dp_size_ * tp_size_ > 1 and vllm_parallel_config.enable_expert_parallel
+        use_ep = (
+            dp_size_ * pcp_size_ * tp_size_ > 1
+            and vllm_parallel_config.enable_expert_parallel
+        )

         dp_size = dp_size_
         dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
-        tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
-            tp_size_, dp_size_, dp_rank
+        pcp_size = pcp_size_
+        pcp_rank = get_pcp_group().rank_in_group if pcp_size > 1 else 0
+        tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
+            tp_size_, dp_size_, dp_rank, pcp_size_, pcp_rank
         )

         if not use_ep:
             return FusedMoEParallelConfig(
                 tp_size=tp_size,
                 tp_rank=tp_rank,
+                pcp_size=pcp_size,
+                pcp_rank=pcp_rank,
                 dp_size=dp_size,
                 dp_rank=dp_rank,
                 ep_size=1,
@@ -826,6 +845,8 @@ class FusedMoEParallelConfig:
         return FusedMoEParallelConfig(
             tp_size=1,
             tp_rank=0,
+            pcp_size=pcp_size,
+            pcp_rank=pcp_rank,
             dp_size=dp_size,
             dp_rank=dp_rank,
             ep_size=ep_size,
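A quick standalone check of the flattened-TP arithmetic above (sizes invented): with EP disabled, the dp_size * pcp_size * tp_size devices are treated as one wide TP group, and each (dp_rank, pcp_rank, tp_rank) triple maps to a unique flattened rank.

dp_size, pcp_size, tp_size = 2, 2, 2
flatten_tp_size = dp_size * pcp_size * tp_size  # 8

for dp_rank in range(dp_size):
    for pcp_rank in range(pcp_size):
        for tp_rank in range(tp_size):
            flatten_tp_rank = dp_rank * pcp_size * tp_size + pcp_rank * tp_size + tp_rank
            print((dp_rank, pcp_rank, tp_rank), "->", flatten_tp_rank)  # 0..7, no collisions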


@@ -18,6 +18,7 @@ from vllm.config.parallel import ExpertPlacementStrategy
 from vllm.distributed import (
     get_dp_group,
     get_ep_group,
+    get_pcp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
@@ -343,6 +344,7 @@ class FusedMoE(CustomOp):
         tp_size: int | None = None,
         ep_size: int | None = None,
         dp_size: int | None = None,
+        pcp_size: int | None = None,
         prefix: str = "",
         custom_routing_function: Callable | None = None,
         scoring_func: str = "softmax",
@@ -398,12 +400,14 @@ class FusedMoE(CustomOp):
             tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
         )
         dp_size_ = dp_size if dp_size is not None else get_dp_group().world_size
+        pcp_size_ = pcp_size if pcp_size is not None else get_pcp_group().world_size

         self.is_sequence_parallel = is_sequence_parallel
         self.sp_size = tp_size_ if is_sequence_parallel else 1

         self.moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
             tp_size_=tp_size_,
+            pcp_size_=pcp_size_,
             dp_size_=dp_size_,
             vllm_parallel_config=vllm_config.parallel_config,
         )
@@ -679,6 +683,10 @@ class FusedMoE(CustomOp):
     def dp_size(self):
         return self.moe_parallel_config.dp_size

+    @property
+    def pcp_size(self):
+        return self.moe_parallel_config.pcp_size
+
     @property
     def ep_size(self):
         return self.moe_parallel_config.ep_size
@@ -691,6 +699,10 @@ class FusedMoE(CustomOp):
     def dp_rank(self):
         return self.moe_parallel_config.dp_rank

+    @property
+    def pcp_rank(self):
+        return self.moe_parallel_config.pcp_rank
+
     @property
     def ep_rank(self):
         return self.moe_parallel_config.ep_rank
@@ -1871,6 +1883,19 @@ class FusedMoE(CustomOp):
             assert self.shared_experts is not None
             shared_output = self.shared_experts(hidden_states)

+        # NOTE: Similar to DP, PCP also needs dispatch and combine. For
+        # simplicity, AgRsAll2All was added separately for PCP here. Maybe
+        # we should modify the All2AllManager abstract to better support PCP.
+        if self.pcp_size > 1:
+            hidden_states = get_pcp_group().all_gather(
+                hidden_states,
+                dim=0,
+            )
+            router_logits = get_pcp_group().all_gather(
+                router_logits,
+                dim=0,
+            )
+
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
             layer=self,
@@ -1925,6 +1950,13 @@ class FusedMoE(CustomOp):
         def combine_output(states: torch.Tensor) -> torch.Tensor:
             if do_naive_dispatch_combine:
                 states = get_ep_group().combine(states, self.is_sequence_parallel)
+            if self.pcp_size > 1:
+                states = get_pcp_group().reduce_scatter(
+                    states,
+                    dim=0,
+                )
             return states

         if self.shared_experts is not None:
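The PCP dispatch/combine added above mirrors the DP pattern: each PCP rank holds a disjoint slice of the prefill tokens, the experts run on the all-gathered sequence, and the reduce-scatter hands each rank back only its own slice. A single-process, shape-level sketch (no real collectives; the names are illustrative, not vLLM API):

import torch

pcp_size, tokens_per_rank, hidden = 2, 4, 8
per_rank = [torch.randn(tokens_per_rank, hidden) for _ in range(pcp_size)]

gathered = torch.cat(per_rank, dim=0)                     # all_gather(dim=0): [8, 8]
per_rank_out = [gathered * 2.0 for _ in range(pcp_size)]  # stand-in for the expert MLPs

summed = torch.stack(per_rank_out).sum(dim=0)             # reduce half of reduce_scatter
local_out = summed.chunk(pcp_size, dim=0)[0]              # scatter: rank 0 keeps its slice
print(gathered.shape, local_out.shape)                    # torch.Size([8, 8]) torch.Size([4, 8])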


@@ -13,6 +13,7 @@ from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (
     get_dp_group,
     get_ep_group,
+    get_pcp_group,
     get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -322,10 +323,12 @@ class GptOssModel(nn.Module):
         # In MoE, we need to flatten the tensor parallel size across the data
         # parallel size when EP is disabled.
-        tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
+        tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
             tp_size=get_tensor_model_parallel_world_size(),
             dp_size=get_dp_group().world_size,
             dp_rank=get_dp_group().rank_in_group,
+            pcp_size=get_pcp_group().world_size,
+            pcp_rank=get_pcp_group().rank_in_group,
         )

         intermediate_size = self.config.intermediate_size
@@ -507,10 +510,12 @@ class GptOssModel(nn.Module):
         # In MoE, we need to flatten the tensor parallel size across the data
         # parallel size when EP is disabled.
-        tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
+        tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
             tp_size=get_tensor_model_parallel_world_size(),
             dp_size=get_dp_group().world_size,
             dp_rank=get_dp_group().rank_in_group,
+            pcp_size=get_pcp_group().world_size,
+            pcp_rank=get_pcp_group().rank_in_group,
         )

         intermediate_size = self.config.intermediate_size


@@ -265,8 +265,8 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
             self.dcp_world_size = 1
             self.dcp_rank = 0

-        self.dcp_kv_cache_interleave_size = (
-            self.parallel_config.dcp_kv_cache_interleave_size
+        self.cp_kv_cache_interleave_size = (
+            self.parallel_config.cp_kv_cache_interleave_size
         )

         self.use_full_cuda_graph = (
@@ -388,7 +388,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
                 dcp_context_kv_lens_cpu,
                 self.dcp_world_size,
                 self.dcp_rank,
-                self.dcp_kv_cache_interleave_size,
+                self.cp_kv_cache_interleave_size,
             )
             dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
             max_dcp_context_kv_len = dcp_context_kv_lens.max().item()


@@ -536,7 +536,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
             # DCP might not be initialized in testing
             self.dcp_world_size = 1
             self.dcp_rank = 0
-        self.dcp_local_block_size = parallel_config.dcp_kv_cache_interleave_size
+        self.dcp_local_block_size = parallel_config.cp_kv_cache_interleave_size
         self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size

         # Don't try to access the runner on AMD
@@ -1289,8 +1289,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
                 get_current_vllm_config()
             )
         )
-        self.dcp_kv_cache_interleave_size: int = (
-            get_current_vllm_config().parallel_config.dcp_kv_cache_interleave_size
+        self.cp_kv_cache_interleave_size: int = (
+            get_current_vllm_config().parallel_config.cp_kv_cache_interleave_size
         )

     def _flash_attn_varlen_diff_headdims(


@@ -1080,9 +1080,9 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
 def get_dcp_local_seq_lens(
     seq_lens: torch.Tensor,
-    dcp_world_size: int = 1,
+    dcp_size: int = 1,
     dcp_rank: int | None = None,
-    dcp_kv_cache_interleave_size: int = 1,
+    cp_kv_cache_interleave_size: int = 1,
 ) -> torch.Tensor:
     """While using dcp, kv_cache size stored on each rank may be different,
     use this function to calculate split decode seq_lens of each dcp rank.
@@ -1091,7 +1091,7 @@ def get_dcp_local_seq_lens(
     num_requests = seq_lens.size(0)
     if dcp_rank is None:
         rank_offsets = (
-            torch.arange(dcp_world_size, dtype=torch.int32)
+            torch.arange(dcp_size, dtype=torch.int32)
             .unsqueeze(0)
             .repeat(num_requests, 1)
         )
@@ -1102,15 +1102,15 @@ def get_dcp_local_seq_lens(
     )
     base = (
         seq_lens_tiled
-        // dcp_kv_cache_interleave_size
-        // dcp_world_size
-        * dcp_kv_cache_interleave_size
+        // cp_kv_cache_interleave_size
+        // dcp_size
+        * cp_kv_cache_interleave_size
     )
-    remainder = seq_lens_tiled - base * dcp_world_size
+    remainder = seq_lens_tiled - base * dcp_size
     remainder = torch.clip(
-        remainder - rank_offsets * dcp_kv_cache_interleave_size,
+        remainder - rank_offsets * cp_kv_cache_interleave_size,
         0,
-        dcp_kv_cache_interleave_size,
+        cp_kv_cache_interleave_size,
     )
     dcp_local_seq_lens = base + remainder
     return dcp_local_seq_lens.squeeze(1)
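A worked example of the split computed above, rewritten in plain Python with the same arithmetic (the function name is illustrative): 10 tokens interleaved over 2 DCP ranks in chunks of cp_kv_cache_interleave_size=4.

def local_kv_len(seq_len: int, rank: int, dcp_size: int, interleave: int) -> int:
    base = seq_len // interleave // dcp_size * interleave
    remainder = seq_len - base * dcp_size
    extra = min(max(remainder - rank * interleave, 0), interleave)
    return base + extra

print([local_kv_len(10, r, dcp_size=2, interleave=4) for r in range(2)])
# [6, 4] -> rank 0 stores tokens 0-3 and 8-9, rank 1 stores tokens 4-7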


@@ -27,6 +27,7 @@ class KVCacheCoordinator(ABC):
         enable_caching: bool,
         enable_kv_cache_events: bool,
         dcp_world_size: int,
+        pcp_world_size: int,
     ):
         self.kv_cache_config = kv_cache_config
         self.max_model_len = max_model_len
@@ -44,6 +45,7 @@ class KVCacheCoordinator(ABC):
                 block_pool=self.block_pool,
                 kv_cache_group_id=i,
                 dcp_world_size=dcp_world_size,
+                pcp_world_size=pcp_world_size,
             )
             for i, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups)
         )
@@ -210,6 +212,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
         use_eagle: bool,
         enable_kv_cache_events: bool,
         dcp_world_size: int,
+        pcp_world_size: int,
     ):
         super().__init__(
             kv_cache_config,
@@ -218,6 +221,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
             False,
             enable_kv_cache_events,
             dcp_world_size=dcp_world_size,
+            pcp_world_size=pcp_world_size,
         )
         self.num_single_type_manager = len(self.single_type_managers)
@@ -250,6 +254,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
         enable_caching: bool,
         enable_kv_cache_events: bool,
         dcp_world_size: int,
+        pcp_world_size: int,
     ):
         super().__init__(
             kv_cache_config,
@@ -258,12 +263,16 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
             enable_caching,
             enable_kv_cache_events,
             dcp_world_size=dcp_world_size,
+            pcp_world_size=pcp_world_size,
         )
         self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec
         self.block_size = self.kv_cache_spec.block_size
         self.dcp_world_size = dcp_world_size
+        self.pcp_world_size = pcp_world_size
         if dcp_world_size > 1:
             self.block_size *= dcp_world_size
+        if pcp_world_size > 1:
+            self.block_size *= pcp_world_size
         assert len(self.kv_cache_config.kv_cache_groups) == 1, (
             "UnitaryKVCacheCoordinator assumes only one kv cache group"
         )
@@ -281,6 +290,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
             kv_cache_spec=self.kv_cache_spec,
             use_eagle=self.use_eagle,
             dcp_world_size=self.dcp_world_size,
+            pcp_world_size=self.pcp_world_size,
         )
         return hit_blocks, len(hit_blocks[0]) * self.block_size
@@ -302,6 +312,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
         enable_caching: bool,
         enable_kv_cache_events: bool,
         dcp_world_size: int,
+        pcp_world_size: int,
     ):
         super().__init__(
             kv_cache_config,
@@ -310,8 +321,10 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
             enable_caching,
             enable_kv_cache_events,
             dcp_world_size=dcp_world_size,
+            pcp_world_size=pcp_world_size,
         )
         assert dcp_world_size == 1, "DCP not support hybrid attn now."
+        assert pcp_world_size == 1, "PCP not support hybrid attn now."
         self.verify_and_split_kv_cache_groups()

     def verify_and_split_kv_cache_groups(self) -> None:
@@ -452,6 +465,7 @@ def get_kv_cache_coordinator(
     enable_caching: bool,
     enable_kv_cache_events: bool,
     dcp_world_size: int,
+    pcp_world_size: int,
 ) -> KVCacheCoordinator:
     if not enable_caching:
         return KVCacheCoordinatorNoPrefixCache(
@@ -460,6 +474,7 @@ def get_kv_cache_coordinator(
             use_eagle,
             enable_kv_cache_events,
             dcp_world_size=dcp_world_size,
+            pcp_world_size=pcp_world_size,
         )
     if len(kv_cache_config.kv_cache_groups) == 1:
         return UnitaryKVCacheCoordinator(
@@ -469,6 +484,7 @@ def get_kv_cache_coordinator(
             enable_caching,
             enable_kv_cache_events,
             dcp_world_size=dcp_world_size,
+            pcp_world_size=pcp_world_size,
         )
     return HybridKVCacheCoordinator(
         kv_cache_config,
@@ -477,4 +493,5 @@ def get_kv_cache_coordinator(
         enable_caching,
         enable_kv_cache_events,
         dcp_world_size=dcp_world_size,
+        pcp_world_size=pcp_world_size,
     )


@@ -100,6 +100,7 @@ class KVCacheManager:
         log_stats: bool = False,
         enable_kv_cache_events: bool = False,
         dcp_world_size: int = 1,
+        pcp_world_size: int = 1,
     ) -> None:
         self.max_model_len = max_model_len
@@ -124,12 +125,9 @@ class KVCacheManager:
             0
         ].kv_cache_spec.block_size
-        if dcp_world_size > 1:
+        if dcp_world_size * pcp_world_size > 1:
             assert len(kv_cache_config.kv_cache_groups) == 1
-            # Note(hc): need revisit. When both DCP and any future
-            # PCP are enabled, the block_size may need to be scaled
-            # by a factor of dcp_size × pcp_size?
-            self.block_size *= dcp_world_size
+            self.block_size *= dcp_world_size * pcp_world_size

         self.coordinator = get_kv_cache_coordinator(
             kv_cache_config=kv_cache_config,
@@ -138,6 +136,7 @@ class KVCacheManager:
             enable_caching=self.enable_caching,
             enable_kv_cache_events=enable_kv_cache_events,
             dcp_world_size=dcp_world_size,
+            pcp_world_size=pcp_world_size,
         )
         self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
         self.block_pool = self.coordinator.block_pool


@@ -1219,11 +1219,16 @@ def _report_kv_cache_config(
         // len(kv_cache_config.kv_cache_groups)
         * min_block_size
     )
-    if vllm_config.parallel_config.decode_context_parallel_size > 1:
-        num_tokens *= vllm_config.parallel_config.decode_context_parallel_size
+    dcp_size = vllm_config.parallel_config.decode_context_parallel_size
+    pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
+    if pcp_size * dcp_size > 1:
+        num_tokens *= pcp_size * dcp_size
         logger.info(
-            "Multiplying the GPU KV cache size by the dcp_world_size %d.",
-            vllm_config.parallel_config.decode_context_parallel_size,
+            "Multiplying the GPU KV cache size by the cp_world_size %d "
+            "(pcp_world_size %d * dcp_world_size %d).",
+            pcp_size * dcp_size,
+            pcp_size,
+            dcp_size,
         )
     num_tokens_str = f"{num_tokens:,}"
     logger.info_once("GPU KV cache size: %s tokens", num_tokens_str, scope="local")


@@ -121,6 +121,7 @@ class Scheduler(SchedulerInterface):
         self.block_size = block_size
         self.dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
+        self.pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size

         # req_id -> Request
         self.requests: dict[str, Request] = {}
@@ -183,6 +184,7 @@ class Scheduler(SchedulerInterface):
             log_stats=self.log_stats,
             enable_kv_cache_events=self.enable_kv_cache_events,
             dcp_world_size=self.dcp_world_size,
+            pcp_world_size=self.pcp_world_size,
         )

         self.use_pp = self.parallel_config.pipeline_parallel_size > 1


@ -32,6 +32,7 @@ class SingleTypeKVCacheManager(ABC):
block_pool: BlockPool, block_pool: BlockPool,
kv_cache_group_id: int, kv_cache_group_id: int,
dcp_world_size: int = 1, dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> None: ) -> None:
""" """
Initializes the SingleTypeKVCacheManager. Initializes the SingleTypeKVCacheManager.
@ -42,8 +43,9 @@ class SingleTypeKVCacheManager(ABC):
""" """
self.block_size = kv_cache_spec.block_size self.block_size = kv_cache_spec.block_size
self.dcp_world_size = dcp_world_size self.dcp_world_size = dcp_world_size
if self.dcp_world_size > 1: self.pcp_world_size = pcp_world_size
self.block_size *= dcp_world_size if dcp_world_size * pcp_world_size > 1:
self.block_size *= dcp_world_size * pcp_world_size
self.kv_cache_spec = kv_cache_spec self.kv_cache_spec = kv_cache_spec
self.block_pool = block_pool self.block_pool = block_pool
@ -212,6 +214,7 @@ class SingleTypeKVCacheManager(ABC):
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
use_eagle: bool, use_eagle: bool,
dcp_world_size: int = 1, dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]: ) -> tuple[list[KVCacheBlock], ...]:
""" """
Get the longest cache hit prefix of the blocks that is not longer than Get the longest cache hit prefix of the blocks that is not longer than
@ -303,6 +306,7 @@ class FullAttentionManager(SingleTypeKVCacheManager):
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
use_eagle: bool, use_eagle: bool,
dcp_world_size: int = 1, dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]: ) -> tuple[list[KVCacheBlock], ...]:
assert isinstance( assert isinstance(
kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec) kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec)
@ -314,8 +318,8 @@ class FullAttentionManager(SingleTypeKVCacheManager):
[] for _ in range(len(kv_cache_group_ids)) [] for _ in range(len(kv_cache_group_ids))
) )
block_size = kv_cache_spec.block_size block_size = kv_cache_spec.block_size
if dcp_world_size > 1: if dcp_world_size * pcp_world_size > 1:
block_size *= dcp_world_size block_size *= dcp_world_size * pcp_world_size
max_num_blocks = max_length // block_size max_num_blocks = max_length // block_size
for block_hash in itertools.islice(block_hashes, max_num_blocks): for block_hash in itertools.islice(block_hashes, max_num_blocks):
# block_hashes is a chain of block hashes. If a block hash is not # block_hashes is a chain of block hashes. If a block hash is not
@ -362,11 +366,13 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
use_eagle: bool, use_eagle: bool,
dcp_world_size: int = 1, dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]: ) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(kv_cache_spec, SlidingWindowSpec), ( assert isinstance(kv_cache_spec, SlidingWindowSpec), (
"SlidingWindowManager can only be used for sliding window groups" "SlidingWindowManager can only be used for sliding window groups"
) )
assert dcp_world_size == 1, "DCP not support sliding window attn now." assert dcp_world_size == 1, "DCP not support sliding window attn now."
assert pcp_world_size == 1, "PCP not support sliding window attn now."
# The number of contiguous blocks needed for prefix cache hit. # The number of contiguous blocks needed for prefix cache hit.
# -1 since the input token itself is also included in the window # -1 since the input token itself is also included in the window
@ -476,6 +482,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
use_eagle: bool, use_eagle: bool,
dcp_world_size: int = 1, dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]: ) -> tuple[list[KVCacheBlock], ...]:
""" """
For chunked local attention, we need to find the longest cache hit For chunked local attention, we need to find the longest cache hit
@ -516,6 +523,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
"Hybrid KV cache is not supported for " + "eagle + chunked local attention." "Hybrid KV cache is not supported for " + "eagle + chunked local attention."
) )
assert dcp_world_size == 1, "DCP not support chunked local attn now." assert dcp_world_size == 1, "DCP not support chunked local attn now."
assert pcp_world_size == 1, "PCP not support chunked local attn now."
max_num_blocks = max_length // kv_cache_spec.block_size max_num_blocks = max_length // kv_cache_spec.block_size
if max_length > 0: if max_length > 0:
local_attention_start_idx = ( local_attention_start_idx = (
@ -611,11 +619,13 @@ class MambaManager(SingleTypeKVCacheManager):
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
use_eagle: bool, use_eagle: bool,
dcp_world_size: int = 1, dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]: ) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(kv_cache_spec, MambaSpec), ( assert isinstance(kv_cache_spec, MambaSpec), (
"MambaManager can only be used for mamba groups" "MambaManager can only be used for mamba groups"
) )
assert dcp_world_size == 1, "DCP does not support mamba yet." assert dcp_world_size == 1, "DCP does not support mamba yet."
assert pcp_world_size == 1, "PCP does not support mamba yet."
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
[] for _ in range(len(kv_cache_group_ids)) [] for _ in range(len(kv_cache_group_ids))
) )
@ -705,6 +715,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
use_eagle: bool, use_eagle: bool,
dcp_world_size: int = 1, dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]: ) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(kv_cache_spec, CrossAttentionSpec), ( assert isinstance(kv_cache_spec, CrossAttentionSpec), (
"CrossAttentionManager can only be used for cross-attention groups" "CrossAttentionManager can only be used for cross-attention groups"
View File
@ -128,6 +128,7 @@ class EngineCore:
scheduler_block_size = ( scheduler_block_size = (
vllm_config.cache_config.block_size vllm_config.cache_config.block_size
* vllm_config.parallel_config.decode_context_parallel_size * vllm_config.parallel_config.decode_context_parallel_size
* vllm_config.parallel_config.prefill_context_parallel_size
) )
self.scheduler: SchedulerInterface = Scheduler( self.scheduler: SchedulerInterface = Scheduler(
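The scheduler works in virtual blocks that span every context-parallel rank, so its block size is the cache block size times both CP degrees. Illustrative arithmetic (values are examples only):

    block_size = 16                       # cache_config.block_size (illustrative)
    decode_context_parallel_size = 2      # DCP degree
    prefill_context_parallel_size = 2     # PCP degree
    scheduler_block_size = (
        block_size * decode_context_parallel_size * prefill_context_parallel_size
    )
    assert scheduler_block_size == 64     # the scheduler allocates in 64-token virtual blocks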
View File
@ -35,6 +35,7 @@ from vllm.distributed.parallel_state import (
get_dp_group, get_dp_group,
get_ep_group, get_ep_group,
get_inner_dp_world_group, get_inner_dp_world_group,
get_pcp_group,
get_pp_group, get_pp_group,
get_tp_group, get_tp_group,
) )
@ -110,12 +111,14 @@ class MultiprocExecutor(Executor):
f"({self.parallel_config.nnodes_within_dp}). " f"({self.parallel_config.nnodes_within_dp}). "
) )
self.local_world_size = self.parallel_config.local_world_size self.local_world_size = self.parallel_config.local_world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size tp_size = self.parallel_config.tensor_parallel_size
pp_parallel_size = self.parallel_config.pipeline_parallel_size pp_size = self.parallel_config.pipeline_parallel_size
assert self.world_size == tensor_parallel_size * pp_parallel_size, ( pcp_size = self.parallel_config.prefill_context_parallel_size
assert self.world_size == tp_size * pp_size * pcp_size, (
f"world_size ({self.world_size}) must be equal to the " f"world_size ({self.world_size}) must be equal to the "
f"tensor_parallel_size ({tensor_parallel_size}) x pipeline" f"tensor_parallel_size ({tp_size}) x pipeline"
f"_parallel_size ({pp_parallel_size}). " f"_parallel_size ({pp_size}) x prefill_context"
f"_parallel_size ({pcp_size}). "
) )
# Set multiprocessing envs # Set multiprocessing envs
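With PCP added to the topology, each replica's world size must cover all three dimensions. A worked check mirroring the assert above (sizes are illustrative, not a recommended config):

    world_size = 16
    tensor_parallel_size = 4
    pipeline_parallel_size = 2
    prefill_context_parallel_size = 2
    assert world_size == (
        tensor_parallel_size * pipeline_parallel_size * prefill_context_parallel_size
    ), "world_size must equal TP x PP x PCP"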
@ -424,7 +427,11 @@ class MultiprocExecutor(Executor):
# 16-23, PP rank 2 # 16-23, PP rank 2
# 24-31, PP rank 3 # 24-31, PP rank 3
# so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3) # so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3)
return self.world_size - self.parallel_config.tensor_parallel_size return (
self.world_size
- self.parallel_config.tensor_parallel_size
* self.parallel_config.prefill_context_parallel_size
)
@dataclass @dataclass
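The return above now subtracts a full stage of TP x PCP ranks to find the first worker of the last pipeline stage. A worked example under the PP-major layout described in the comment (numbers illustrative):

    # Assumed layout: PP-major blocks of TP x PCP ranks.
    world_size, tp_size, pcp_size = 16, 4, 2          # implies 2 PP stages of 8 ranks each
    last_stage_first_rank = world_size - tp_size * pcp_size
    assert last_stage_first_rank == 8                 # ranks 8..15 form the last PP stage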
@ -828,6 +835,8 @@ class WorkerProc:
dp_rank = get_dp_group().rank_in_group dp_rank = get_dp_group().rank_in_group
pp_size = get_pp_group().world_size pp_size = get_pp_group().world_size
pp_rank = get_pp_group().rank_in_group pp_rank = get_pp_group().rank_in_group
pcp_size = get_pcp_group().world_size
pcp_rank = get_pcp_group().rank_in_group
tp_size = get_tp_group().world_size tp_size = get_tp_group().world_size
tp_rank = get_tp_group().rank_in_group tp_rank = get_tp_group().rank_in_group
dcp_size = get_dcp_group().world_size dcp_size = get_dcp_group().world_size
@ -837,6 +846,8 @@ class WorkerProc:
process_name += f"_DP{dp_rank}" process_name += f"_DP{dp_rank}"
if pp_size > 1: if pp_size > 1:
process_name += f"_PP{pp_rank}" process_name += f"_PP{pp_rank}"
if pcp_size > 1:
process_name += f"_PCP{pcp_rank}"
if tp_size > 1: if tp_size > 1:
process_name += f"_TP{tp_rank}" process_name += f"_TP{tp_rank}"
if dcp_size > 1: if dcp_size > 1:
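Worker process names gain a _PCP<rank> segment when pcp_size > 1. A sketch of the resulting naming, assuming a base name of "VllmWorker" purely for illustration:

    def worker_process_name(dp_rank=0, pp_rank=0, pcp_rank=0, tp_rank=0, dcp_rank=0,
                            dp_size=1, pp_size=1, pcp_size=1, tp_size=1, dcp_size=1) -> str:
        name = "VllmWorker"               # base name assumed for illustration
        if dp_size > 1:
            name += f"_DP{dp_rank}"
        if pp_size > 1:
            name += f"_PP{pp_rank}"
        if pcp_size > 1:
            name += f"_PCP{pcp_rank}"
        if tp_size > 1:
            name += f"_TP{tp_rank}"
        if dcp_size > 1:
            name += f"_DCP{dcp_rank}"
        return name

    assert worker_process_name(pp_rank=1, pcp_rank=0, tp_rank=3,
                               pp_size=2, pcp_size=2, tp_size=4) == "VllmWorker_PP1_PCP0_TP3"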
View File
@ -95,10 +95,11 @@ class FullAttentionSpec(AttentionSpec):
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_model_len = vllm_config.model_config.max_model_len max_model_len = vllm_config.model_config.max_model_len
dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size
# Note(hc): each dcp rank only needs to save # Note(hc): each dcp/pcp rank only needs to save
# (max_model_len//dcp_world_size) tokens locally. # (max_model_len//(dcp_world_size*pcp_world_size)) tokens locally.
if dcp_world_size > 1: if dcp_world_size * pcp_world_size > 1:
max_model_len = cdiv(max_model_len, dcp_world_size) max_model_len = cdiv(max_model_len, dcp_world_size * pcp_world_size)
return cdiv(max_model_len, self.block_size) * self.page_size_bytes return cdiv(max_model_len, self.block_size) * self.page_size_bytes
@classmethod @classmethod
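Because each CP rank stores only max_model_len / (dcp * pcp) tokens, the per-rank full-attention KV footprint shrinks by the same factor. A hedged numeric sketch reusing the cdiv arithmetic above (the page size is an invented value):

    def cdiv(a: int, b: int) -> int:
        return -(-a // b)

    max_model_len = 32768
    block_size = 16
    page_size_bytes = 2 * 16 * 8 * 128 * 2     # illustrative: K+V, 16 tokens, 8 heads, dim 128, fp16
    dcp_world_size, pcp_world_size = 2, 2

    if dcp_world_size * pcp_world_size > 1:
        max_model_len = cdiv(max_model_len, dcp_world_size * pcp_world_size)   # 8192 tokens per rank

    per_rank_bytes = cdiv(max_model_len, block_size) * page_size_bytes
    assert per_rank_bytes == 512 * page_size_bytes     # 4x fewer blocks than without CP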
View File
@ -4,7 +4,7 @@
import numpy as np import numpy as np
import torch import torch
from vllm.distributed import get_dcp_group from vllm.distributed import get_dcp_group, get_pcp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer from vllm.v1.utils import CpuGpuBuffer
@ -22,7 +22,7 @@ class BlockTable:
pin_memory: bool, pin_memory: bool,
device: torch.device, device: torch.device,
kernel_block_size: int, kernel_block_size: int,
dcp_kv_cache_interleave_size: int, cp_kv_cache_interleave_size: int,
): ):
""" """
Args: Args:
@ -80,6 +80,13 @@ class BlockTable:
else: else:
self._kernel_block_arange = None self._kernel_block_arange = None
try:
self.pcp_world_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group().rank_in_group
except AssertionError:
# PCP might not be initialized in testing
self.pcp_world_size = 1
self.pcp_rank = 0
try: try:
self.dcp_world_size = get_dcp_group().world_size self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group self.dcp_rank = get_dcp_group().rank_in_group
@ -87,7 +94,7 @@ class BlockTable:
# DCP might not be initialized in testing # DCP might not be initialized in testing
self.dcp_world_size = 1 self.dcp_world_size = 1
self.dcp_rank = 0 self.dcp_rank = 0
self.dcp_kv_cache_interleave_size = dcp_kv_cache_interleave_size self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
def append_row( def append_row(
self, self,
@ -131,14 +138,16 @@ class BlockTable:
# NOTE(woosuk): We can't simply use `token_indices // block_size` # NOTE(woosuk): We can't simply use `token_indices // block_size`
# here because M (max_model_len) is not necessarily divisible by # here because M (max_model_len) is not necessarily divisible by
# block_size. # block_size.
if self.dcp_world_size > 1: total_cp_world_size = self.pcp_world_size * self.dcp_world_size
total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
if total_cp_world_size > 1:
# Note(hc): The DCP implementation stores the kvcache in an interleaved # Note(hc): The DCP implementation stores the kvcache in an interleaved
# style: the kvcache for the token whose token_idx is i is # style: the kvcache for the token whose token_idx is i is
# always stored on the GPU whose dcp_rank equals i % cp_world_size: # always stored on the GPU whose dcp_rank equals i % cp_world_size:
# Use a "virtual block" which equals to world_size * block_size # Use a "virtual block" which equals to world_size * block_size
# for block_table_indices calculation. # for block_table_indices calculation.
virtual_block_size = self.block_size * self.dcp_world_size virtual_block_size = self.block_size * total_cp_world_size
block_table_indices = ( block_table_indices = (
req_indices * self.max_num_blocks_per_req req_indices * self.max_num_blocks_per_req
+ positions // virtual_block_size + positions // virtual_block_size
@ -150,16 +159,16 @@ class BlockTable:
virtual_block_offsets = positions % virtual_block_size virtual_block_offsets = positions % virtual_block_size
mask = ( mask = (
virtual_block_offsets virtual_block_offsets
// self.dcp_kv_cache_interleave_size // self.cp_kv_cache_interleave_size
% self.dcp_world_size % total_cp_world_size
== self.dcp_rank == total_cp_rank
) )
# Calculate local block_offsets # Calculate local block_offsets
block_offsets = ( block_offsets = (
virtual_block_offsets virtual_block_offsets
// (self.dcp_world_size * self.dcp_kv_cache_interleave_size) // (total_cp_world_size * self.cp_kv_cache_interleave_size)
* self.dcp_kv_cache_interleave_size * self.cp_kv_cache_interleave_size
+ virtual_block_offsets % self.dcp_kv_cache_interleave_size + virtual_block_offsets % self.cp_kv_cache_interleave_size
) )
# Calculate slot_mapping # Calculate slot_mapping
slot_mapping = block_numbers * self.block_size + block_offsets slot_mapping = block_numbers * self.block_size + block_offsets
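The slot mapping above treats a virtual block of block_size * (dcp * pcp) tokens, deals tokens out to CP ranks in runs of cp_kv_cache_interleave_size, and packs each rank's tokens into contiguous local offsets. A small scalar re-derivation of the same arithmetic (the real code is vectorized over numpy arrays):

    # Illustrative parameters: 4-token blocks, DCP=2 x PCP=2, interleave runs of 2 tokens.
    block_size = 4
    dcp_world_size, pcp_world_size = 2, 2
    interleave = 2                                         # cp_kv_cache_interleave_size
    total_cp_world_size = dcp_world_size * pcp_world_size
    virtual_block_size = block_size * total_cp_world_size  # 16 tokens per virtual block

    def owner_and_local_offset(pos: int) -> tuple[int, int]:
        v_off = pos % virtual_block_size
        owner = (v_off // interleave) % total_cp_world_size
        local_off = (v_off // (total_cp_world_size * interleave)) * interleave + v_off % interleave
        return owner, local_off

    # Tokens are dealt out in runs of 2: 0-1 -> rank 0, 2-3 -> rank 1, ..., 6-7 -> rank 3, then wrap.
    assert [owner_and_local_offset(p)[0] for p in range(10)] == [0, 0, 1, 1, 2, 2, 3, 3, 0, 0]
    # Rank 0 packs its tokens densely: positions 0, 1, 8, 9 land at local offsets 0, 1, 2, 3.
    assert [owner_and_local_offset(p)[1] for p in (0, 1, 8, 9)] == [0, 1, 2, 3]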
@ -253,7 +262,7 @@ class MultiGroupBlockTable:
block_sizes: list[int], block_sizes: list[int],
kernel_block_sizes: list[int], kernel_block_sizes: list[int],
num_speculative_tokens: int = 0, num_speculative_tokens: int = 0,
dcp_kv_cache_interleave_size: int = 1, cp_kv_cache_interleave_size: int = 1,
) -> None: ) -> None:
# Note(hc): each dcp rank only stores # Note(hc): each dcp rank only stores
# (max_model_len//dcp_world_size) tokens in kvcache, # (max_model_len//dcp_world_size) tokens in kvcache,
@ -283,7 +292,7 @@ class MultiGroupBlockTable:
pin_memory, pin_memory,
device, device,
kernel_block_size, kernel_block_size,
dcp_kv_cache_interleave_size, cp_kv_cache_interleave_size,
) )
for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes) for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
] ]
View File
@ -87,7 +87,7 @@ class InputBatch:
is_spec_decode: bool = False, is_spec_decode: bool = False,
is_pooling_model: bool = False, is_pooling_model: bool = False,
num_speculative_tokens: int = 0, num_speculative_tokens: int = 0,
dcp_kv_cache_interleave_size: int = 1, cp_kv_cache_interleave_size: int = 1,
): ):
self.is_pooling_model = is_pooling_model self.is_pooling_model = is_pooling_model
self.is_spec_decode = is_spec_decode self.is_spec_decode = is_spec_decode
@ -141,7 +141,7 @@ class InputBatch:
block_sizes=block_sizes, block_sizes=block_sizes,
kernel_block_sizes=kernel_block_sizes, kernel_block_sizes=kernel_block_sizes,
num_speculative_tokens=num_speculative_tokens, num_speculative_tokens=num_speculative_tokens,
dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size, cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
) )
# Sampling-related. # Sampling-related.
View File
@ -426,7 +426,7 @@ class GPUModelRunner(
# uses output token ids so we set this conservatively. # uses output token ids so we set this conservatively.
logitsprocs_need_output_token_ids=bool(custom_logitsprocs), logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
is_pooling_model=self.is_pooling_model, is_pooling_model=self.is_pooling_model,
dcp_kv_cache_interleave_size=self.parallel_config.dcp_kv_cache_interleave_size, cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size,
) )
self.use_async_scheduling = self.scheduler_config.async_scheduling self.use_async_scheduling = self.scheduler_config.async_scheduling
@ -1436,7 +1436,7 @@ class GPUModelRunner(
self.seq_lens.cpu[:num_reqs], self.seq_lens.cpu[:num_reqs],
self.dcp_world_size, self.dcp_world_size,
self.dcp_rank, self.dcp_rank,
self.parallel_config.dcp_kv_cache_interleave_size, self.parallel_config.cp_kv_cache_interleave_size,
) )
self.dcp_local_seq_lens.copy_to_gpu(num_reqs) self.dcp_local_seq_lens.copy_to_gpu(num_reqs)
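The call above converts global sequence lengths into per-DCP-rank local lengths under the interleaved layout. A hedged sketch of that conversion with toy numbers; the actual vLLM helper and its signature may differ:

    def local_seq_len(seq_len: int, world_size: int, rank: int, interleave: int = 1) -> int:
        # Tokens are dealt out in runs of `interleave` across `world_size` ranks.
        stride = world_size * interleave
        full_rounds, rem = divmod(seq_len, stride)
        return full_rounds * interleave + min(max(rem - rank * interleave, 0), interleave)

    # 10 tokens over 4 ranks with interleave=2: local lengths 4, 2, 2, 2 (sum = 10).
    assert [local_seq_len(10, 4, r, interleave=2) for r in range(4)] == [4, 2, 2, 2]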
View File
@ -26,6 +26,7 @@ from vllm.distributed.kv_transfer import (
has_kv_transfer_group, has_kv_transfer_group,
) )
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
get_pcp_group,
get_pp_group, get_pp_group,
get_tp_group, get_tp_group,
) )
@ -733,6 +734,7 @@ class Worker(WorkerBase):
module.global_num_experts = module.moe_config.num_experts module.global_num_experts = module.moe_config.num_experts
module.moe_parallel_config = FusedMoEParallelConfig.make( module.moe_parallel_config = FusedMoEParallelConfig.make(
tp_size_=get_tp_group().world_size, tp_size_=get_tp_group().world_size,
pcp_size_=get_pcp_group().world_size,
dp_size_=get_dp_group().world_size, dp_size_=get_dp_group().world_size,
vllm_parallel_config=parallel_config, vllm_parallel_config=parallel_config,
) )
@ -886,6 +888,7 @@ def init_worker_distributed_environment(
ensure_model_parallel_initialized( ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size, parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size, parallel_config.pipeline_parallel_size,
parallel_config.prefill_context_parallel_size,
parallel_config.decode_context_parallel_size, parallel_config.decode_context_parallel_size,
) )
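init_worker_distributed_environment now also sizes the PCP group when building the model-parallel groups. A hedged sketch of how a flat rank could decompose into (PP, PCP, TP) coordinates under the layout assumed earlier in this diff; the actual group construction order in vLLM may differ:

    def decompose_rank(rank: int, tp_size: int, pcp_size: int) -> tuple[int, int, int]:
        # Assumed PP-major layout over PCP x TP ranks (illustrative only).
        tp_rank = rank % tp_size
        pcp_rank = (rank // tp_size) % pcp_size
        pp_rank = rank // (tp_size * pcp_size)
        return pp_rank, pcp_rank, tp_rank

    # With TP=4 and PCP=2, flat rank 11 sits in PP stage 1, PCP group 0, TP rank 3.
    assert decompose_rank(11, tp_size=4, pcp_size=2) == (1, 0, 3)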