[Feature] Prefill Context Parallel (PCP) basic support (#28718)

Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: LookAround <lixushi@huawei.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
Co-authored-by: FENP <yuanyongjie.yyj@antgroup.com>
Co-authored-by: LookAround <lixushi@huawei.com>
Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com>
Co-authored-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
Co-authored-by: Jingchun Gao <63247409+gjc0824@users.noreply.github.com>
This commit is contained in:
Qiu 2025-11-20 04:52:44 +08:00 committed by GitHub
parent 02f5903b84
commit 2fd893b4ce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 399 additions and 114 deletions

View File

@ -31,7 +31,7 @@ class ParallelSetup(NamedTuple):
tp_size: int
pp_size: int
dcp_size: int
dcp_kv_cache_interleave_size: int
cp_kv_cache_interleave_size: int
eager_mode: bool
chunked_prefill: bool
@ -55,7 +55,7 @@ class CPTestSettings:
tp_base: int = 4,
pp_base: int = 1,
dcp_base: int = 1,
dcp_kv_cache_interleave_size: int = 1,
cp_kv_cache_interleave_size: int = 1,
multi_node_only: bool = False,
runner: RunnerOption = "auto",
load_format: str | None = None,
@ -71,7 +71,7 @@ class CPTestSettings:
tp_size=tp_base,
pp_size=pp_multiplier * pp_base,
dcp_size=int(dcp_multiplier * tp_base),
dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
eager_mode=eager_mode_val,
chunked_prefill=chunked_prefill_val,
)
@ -116,7 +116,7 @@ def _compare_cp_with_tp(
tp_size,
pp_size,
dcp_size,
dcp_kv_cache_interleave_size,
cp_kv_cache_interleave_size,
eager_mode,
chunked_prefill,
) = parallel_setup
@ -197,7 +197,7 @@ def _compare_cp_with_tp(
"--decode-context-parallel-size",
str(dcp_size),
"--dcp-kv-cache-interleave-size",
str(dcp_kv_cache_interleave_size),
str(cp_kv_cache_interleave_size),
"--distributed-executor-backend",
distributed_backend,
]
@ -227,7 +227,7 @@ CP_TEXT_GENERATION_MODELS = {
"deepseek-ai/DeepSeek-V2-Lite-Chat": [
CPTestSettings.detailed(),
CPTestSettings.detailed(tp_base=2),
CPTestSettings.detailed(tp_base=2, dcp_kv_cache_interleave_size=64),
CPTestSettings.detailed(tp_base=2, cp_kv_cache_interleave_size=64),
],
"bigcode/gpt_bigcode-santacoder": [
CPTestSettings.detailed(),

View File

@ -15,7 +15,11 @@ from tests.kernels.quantization.nvfp4_utils import (
)
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig
from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
from vllm.distributed import (
get_dp_group,
get_pcp_group,
get_tensor_model_parallel_world_size,
)
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
@ -561,6 +565,7 @@ def make_modular_kernel(
# make moe config
moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
tp_size_=get_tensor_model_parallel_world_size(),
pcp_size_=get_pcp_group().world_size,
dp_size_=get_dp_group().world_size,
vllm_parallel_config=vllm_config.parallel_config,
)

View File

@ -956,7 +956,7 @@ def test_hybrid_block_table_initialization():
max_num_reqs = 10
max_num_blocks_per_req = 20
max_num_batched_tokens = 512
dcp_kv_cache_interleave_size = 8
cp_kv_cache_interleave_size = 8
block_table = BlockTable(
block_size=block_size,
@ -966,7 +966,7 @@ def test_hybrid_block_table_initialization():
pin_memory=False,
device=torch.device(DEVICE),
kernel_block_size=kernel_block_sizes[0],
dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
)
# Verify hybrid block configuration

View File

@ -266,6 +266,12 @@ class AttentionImpl(ABC, Generic[T]):
dcp_world_size: int
dcp_rank: int
pcp_world_size: int
pcp_rank: int
total_cp_world_size: int
total_cp_rank: int
def __new__(cls, *args, **kwargs):
# use __new__ so that all subclasses will call this
self = super().__new__(cls)
@ -278,6 +284,17 @@ class AttentionImpl(ABC, Generic[T]):
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
try:
from vllm.distributed.parallel_state import get_pcp_group
self.pcp_world_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group().rank_in_group
except AssertionError:
self.pcp_world_size = 1
self.pcp_rank = 0
self.total_cp_world_size = self.pcp_world_size * self.dcp_world_size
self.total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
self.need_to_return_lse_for_decode = (
self.dcp_world_size > 1 and self.can_return_lse_for_decode
)

View File

@ -169,12 +169,11 @@ def correct_attn_out(
return out, lse
def cp_lse_ag_out_rs(
def _cp_lse_common(
cp_attn_out: torch.Tensor,
cp_attn_lse: torch.Tensor,
cp_group: GroupCoordinator,
ctx: CPTritonContext = None,
return_lse=False,
ctx: CPTritonContext | None = None,
):
"""
cp_attn_out: [ B, H, D ]
@ -195,6 +194,22 @@ def cp_lse_ag_out_rs(
cp_attn_lse = cp_attn_lse.contiguous()
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
assert out.is_contiguous()
return out, lse
def cp_lse_ag_out_rs(
cp_attn_out: torch.Tensor,
cp_attn_lse: torch.Tensor,
cp_group: GroupCoordinator,
ctx: CPTritonContext | None = None,
return_lse: bool = False,
):
"""
cp_attn_out: [ B, H, D ]
cp_attn_lse: [ B, H ]
"""
out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
out = cp_group.reduce_scatter(out, dim=1)
if return_lse:
@ -205,6 +220,25 @@ def cp_lse_ag_out_rs(
return out
def cp_lse_ag_out_ar(
cp_attn_out: torch.Tensor,
cp_attn_lse: torch.Tensor,
cp_group: GroupCoordinator,
ctx: CPTritonContext | None = None,
return_lse: bool = False,
):
"""
cp_attn_out: [ B, H, D ]
cp_attn_lse: [ B, H ]
"""
out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
out = cp_group.all_reduce(out)
if return_lse:
return out, lse
return out
@triton.jit
def _pack_seq_kernel(
x_ptr, # [N, D]

View File

@ -71,6 +71,8 @@ class ParallelConfig:
"""Number of pipeline parallel groups."""
tensor_parallel_size: int = 1
"""Number of tensor parallel groups."""
prefill_context_parallel_size: int = 1
"""Number of prefill context parallel groups."""
data_parallel_size: int = 1
"""Number of data parallel groups. MoE layers will be sharded according to
the product of the tensor parallel size and data parallel size."""
@ -239,14 +241,25 @@ class ParallelConfig:
needs to be divisible by dcp_size."""
dcp_kv_cache_interleave_size: int = 1
"""Interleave size of kv_cache storage while using dcp or cp > 1,
store interleave_size tokens on (d)cp i,
then store next interleave_size tokens on (d)cp i+1.
Interleave_size=1: token-level align, token i is stored on rank i % (d)cp_size.
Interleave_size=block_size: block-level align, first fill the block on first rank,
token is stored on rank i+1 block j after rank i block j is full.
Block_size should be greater than or equal to dcp_kv_cache_interleave_size.
Block_size should be divisible by dcp_kv_cache_interleave_size.
"""
Interleave size of kv_cache storage while using DCP.
dcp_kv_cache_interleave_size has been replaced by cp_kv_cache_interleave_size,
and will be deprecated when PCP is fully supported.
"""
cp_kv_cache_interleave_size: int = 1
"""Interleave size of kv_cache storage while using DCP or PCP.
For `total_cp_rank = pcp_rank * dcp_world_size + dcp_rank`,
and `total_cp_world_size = pcp_world_size * dcp_world_szie`.
store interleave_size tokens on total_cp_rank i,
then store next interleave_size tokens on taotal_cp_rank i+1.
Interleave_size=1: token-level alignment, where token `i` is stored on
total_cp_rank `i % total_cp_world_size`.
Interleave_size=block_size: block-level alignment, where tokens are
first populated to the preceding ranks. Tokens are then stored
in (rank i+1, block j) only after (rank i, block j) is fully occupied.
Block_size should be greater than or equal to cp_kv_cache_interleave_size.
Block_size should be divisible by cp_kv_cache_interleave_size.
"""
_api_process_count: int = Field(default=1, gt=0)
@ -311,6 +324,11 @@ class ParallelConfig:
"num_redundant_experts."
)
if self.prefill_context_parallel_size > 1:
raise ValueError(
"Prefill context parallelism is not fully supported. "
"Please set prefill_context_parallel_size to 1."
)
return self
@property
@ -529,7 +547,11 @@ class ParallelConfig:
)
# Continue with the rest of the initialization
self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size
self.world_size = (
self.pipeline_parallel_size
* self.tensor_parallel_size
* self.prefill_context_parallel_size
)
if self.distributed_executor_backend == "external_launcher":
logger.info("Using external launcher for distributed inference.")

View File

@ -481,6 +481,14 @@ class VllmConfig:
"Overriding cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
# prefill context parallel do not support full cudagraphs
elif self.parallel_config.prefill_context_parallel_size > 1:
logger.warning_once(
"Prefill context parallel (PCP) is enabled, which is "
"incompatible with full CUDA graphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
elif self.model_config is not None:
if self.model_config.pooler_config is not None:
logger.warning_once(
@ -610,22 +618,34 @@ class VllmConfig:
# If DCP, ensure the block size is right.
if self.parallel_config.decode_context_parallel_size > 1:
if self.parallel_config.dcp_kv_cache_interleave_size > 1 and (
self.parallel_config.cp_kv_cache_interleave_size
!= self.parallel_config.dcp_kv_cache_interleave_size
):
self.parallel_config.cp_kv_cache_interleave_size = (
self.parallel_config.dcp_kv_cache_interleave_size
)
logger.warning_once(
"cp_kv_cache_interleave_size is overridden by dcp_kv_cache"
"_interleave_size. And dcp-kv-cache-interleave-size will be "
"deprecated when PCP is fully supported."
)
assert (
self.parallel_config.dcp_kv_cache_interleave_size
self.parallel_config.cp_kv_cache_interleave_size
<= self.cache_config.block_size
and self.cache_config.block_size
% self.parallel_config.dcp_kv_cache_interleave_size
% self.parallel_config.cp_kv_cache_interleave_size
== 0
), (
f"Block_size({self.cache_config.block_size}) should be greater "
"than or equal to and divisible by dcp_kv_cache_interleave_size "
f"({self.parallel_config.dcp_kv_cache_interleave_size})."
"than or equal to and divisible by cp_kv_cache_interleave_size "
f"({self.parallel_config.cp_kv_cache_interleave_size})."
)
assert (
self.parallel_config.dcp_kv_cache_interleave_size == 1
self.parallel_config.cp_kv_cache_interleave_size == 1
or self.speculative_config is None
), "MTP with dcp_kv_cache_interleave_size > 1 is not supported now."
), "MTP with cp_kv_cache_interleave_size > 1 is not supported now."
# Do this after all the updates to compilation_config.mode
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:

View File

@ -1098,6 +1098,12 @@ get_context_model_parallel_group = get_dcp_group
_PP: GroupCoordinator | None = None
def get_pp_group() -> GroupCoordinator:
assert _PP is not None, "pipeline model parallel group is not initialized"
return _PP
_DP: GroupCoordinator | None = None
@ -1114,9 +1120,12 @@ def get_ep_group() -> GroupCoordinator:
return _EP
def get_pp_group() -> GroupCoordinator:
assert _PP is not None, "pipeline model parallel group is not initialized"
return _PP
_PCP: GroupCoordinator | None = None
def get_pcp_group() -> GroupCoordinator:
assert _PCP is not None, "prefill context parallel group is not initialized"
return _PCP
@deprecated(
@ -1276,6 +1285,7 @@ def init_distributed_environment(
def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
prefill_context_model_parallel_size: int = 1,
decode_context_model_parallel_size: int | None = 1,
backend: str | None = None,
) -> None:
@ -1325,7 +1335,11 @@ def initialize_model_parallel(
# to get group_ranks for each dimension, transpose that dimension to the
# last dimension, then reshape to 2D, then unbind the last dimension
all_ranks = torch.arange(world_size).reshape(
-1, data_parallel_size, pipeline_model_parallel_size, tensor_model_parallel_size
-1,
data_parallel_size,
pipeline_model_parallel_size,
prefill_context_model_parallel_size,
tensor_model_parallel_size,
) # noqa
# Build the tensor model-parallel groups.
@ -1360,11 +1374,23 @@ def initialize_model_parallel(
group_name="dcp",
)
global _PCP
assert _PCP is None, "prefill context parallel group is already initialized"
group_ranks = (
all_ranks.transpose(3, 4)
.reshape(-1, prefill_context_model_parallel_size)
.unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
_PCP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="pcp"
)
# Build the pipeline model-parallel groups.
global _PP
assert _PP is None, "pipeline model parallel group is already initialized"
group_ranks = (
all_ranks.transpose(2, 3).reshape(-1, pipeline_model_parallel_size).unbind(0)
all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
_PP = init_model_parallel_group(
@ -1373,7 +1399,7 @@ def initialize_model_parallel(
global _DP
assert _DP is None, "data parallel group is already initialized"
group_ranks = all_ranks.transpose(1, 3).reshape(-1, data_parallel_size).unbind(0)
group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_DP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="dp"
@ -1383,7 +1409,12 @@ def initialize_model_parallel(
assert _EP is None, "expert parallel group is already initialized"
group_ranks = (
all_ranks.transpose(1, 2)
.reshape(-1, data_parallel_size * tensor_model_parallel_size)
.reshape(
-1,
data_parallel_size
* prefill_context_model_parallel_size
* tensor_model_parallel_size,
)
.unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
@ -1393,11 +1424,13 @@ def initialize_model_parallel(
logger.info_once(
"rank %s in world size %s is assigned as "
"DP rank %s, PP rank %s, TP rank %s, EP rank %s",
"DP rank %s, PP rank %s, PCP rank %s, "
"TP rank %s, EP rank %s",
rank,
world_size,
_DP.rank_in_group,
_PP.rank_in_group,
_PCP.rank_in_group,
_TP.rank_in_group,
_EP.rank_in_group,
)
@ -1406,6 +1439,7 @@ def initialize_model_parallel(
def ensure_model_parallel_initialized(
tensor_model_parallel_size: int,
pipeline_model_parallel_size: int,
prefill_context_model_parallel_size: int = 1,
decode_context_model_parallel_size: int | None = 1,
backend: str | None = None,
) -> None:
@ -1418,6 +1452,7 @@ def ensure_model_parallel_initialized(
initialize_model_parallel(
tensor_model_parallel_size,
pipeline_model_parallel_size,
prefill_context_model_parallel_size,
decode_context_model_parallel_size,
backend,
)
@ -1434,6 +1469,12 @@ def ensure_model_parallel_initialized(
f"got: {pp_world_size=} vs. "
f"wanted: {pipeline_model_parallel_size=}"
)
pcp_world_size = get_pcp_group().world_size
assert pcp_world_size == prefill_context_model_parallel_size, (
"prefill context parallel group already initialized, but of unexpected size: "
f"{pcp_world_size=} vs. "
f"{prefill_context_model_parallel_size=}"
)
def prepare_communication_buffer_for_model(model: torch.nn.Module):
@ -1445,6 +1486,8 @@ def prepare_communication_buffer_for_model(model: torch.nn.Module):
"""
if _TP is not None:
_TP.prepare_communication_buffer_for_model(model)
if _PCP is not None:
_PCP.prepare_communication_buffer_for_model(model)
if _PP is not None:
_PP.prepare_communication_buffer_for_model(model)
if _DP is not None:
@ -1520,16 +1563,21 @@ def destroy_model_parallel():
_TP.destroy()
_TP = None
global _PP
if _PP:
_PP.destroy()
_PP = None
global _DCP
if _DCP:
_DCP.destroy()
_DCP = None
global _PCP
if _PCP:
_PCP.destroy()
_PCP = None
global _PP
if _PP:
_PP.destroy()
_PP = None
global _DP
if _DP:
_DP.destroy()

View File

@ -389,8 +389,10 @@ class EngineArgs:
nnodes: int = ParallelConfig.nnodes
node_rank: int = ParallelConfig.node_rank
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
prefill_context_parallel_size: int = ParallelConfig.prefill_context_parallel_size
decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
cp_kv_cache_interleave_size: int = ParallelConfig.cp_kv_cache_interleave_size
data_parallel_size: int = ParallelConfig.data_parallel_size
data_parallel_rank: int | None = None
data_parallel_start_rank: int | None = None
@ -770,6 +772,15 @@ class EngineArgs:
"--dcp-kv-cache-interleave-size",
**parallel_kwargs["dcp_kv_cache_interleave_size"],
)
parallel_group.add_argument(
"--cp-kv-cache-interleave-size",
**parallel_kwargs["cp_kv_cache_interleave_size"],
)
parallel_group.add_argument(
"--prefill-context-parallel-size",
"-pcp",
**parallel_kwargs["prefill_context_parallel_size"],
)
parallel_group.add_argument(
"--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
)
@ -1600,6 +1611,7 @@ class EngineArgs:
parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size,
tensor_parallel_size=self.tensor_parallel_size,
prefill_context_parallel_size=self.prefill_context_parallel_size,
data_parallel_size=self.data_parallel_size,
data_parallel_rank=self.data_parallel_rank or 0,
data_parallel_external_lb=data_parallel_external_lb,
@ -1631,6 +1643,7 @@ class EngineArgs:
worker_extension_cls=self.worker_extension_cls,
decode_context_parallel_size=self.decode_context_parallel_size,
dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size,
cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size,
_api_process_count=self._api_process_count,
_api_process_rank=self._api_process_rank,
)
@ -1952,6 +1965,15 @@ class EngineArgs:
default_prefix_caching,
) = self.get_chunked_prefill_prefix_caching_defaults(model_config)
if self.prefill_context_parallel_size > 1:
default_chunked_prefill = False
default_prefix_caching = False
logger.warning(
"--prefill-context-parallel-size > 1 is not compatible with "
"chunked prefill and prefix caching now. Chunked prefill "
"and prefix caching have been disabled by default."
)
if self.enable_chunked_prefill is None:
self.enable_chunked_prefill = default_chunked_prefill

View File

@ -8,7 +8,11 @@ import torch
import vllm.envs as envs
from vllm.config import ParallelConfig
from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank
from vllm.distributed import (
get_dp_group,
get_pcp_group,
get_tensor_model_parallel_rank,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_DTYPES,
@ -684,9 +688,11 @@ FUSED_MOE_UNQUANTIZED_CONFIG: FusedMoEQuantConfig = FusedMoEQuantConfig.make()
@dataclass
class FusedMoEParallelConfig:
tp_size: int
pcp_size: int
dp_size: int
ep_size: int
tp_rank: int
pcp_rank: int
dp_rank: int
ep_rank: int
@ -713,19 +719,22 @@ class FusedMoEParallelConfig:
return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"
@staticmethod
def flatten_tp_across_dp(
tp_size: int, dp_size: int, dp_rank: int
def flatten_tp_across_dp_and_pcp(
tp_size: int, dp_size: int, dp_rank: int, pcp_size: int, pcp_rank: int
) -> tuple[int, int]:
tp_rank = 0 if tp_size == 1 else get_tensor_model_parallel_rank()
# There are actually dp_size * tp_size devices. Update tp_size
# and tp_rank so we shard across all devices.
flatten_tp_size = dp_size * tp_size
flatten_tp_rank = dp_rank * tp_size + tp_rank
# There are actually dp_size * pcp_size * tp_size devices.
# Update tp_size and tp_rank so we shard across all devices.
flatten_tp_size = dp_size * pcp_size * tp_size
flatten_tp_rank = dp_rank * pcp_size * tp_size + pcp_rank * tp_size + tp_rank
return flatten_tp_size, flatten_tp_rank
@staticmethod
def make(
tp_size_: int, dp_size_: int, vllm_parallel_config: ParallelConfig
tp_size_: int,
pcp_size_: int,
dp_size_: int,
vllm_parallel_config: ParallelConfig,
) -> "FusedMoEParallelConfig":
"""
Determine MoE parallel configuration. Based on the input `tp_size_`,
@ -734,19 +743,22 @@ class FusedMoEParallelConfig:
Args:
tp_size_ (int): `tp_size` passed into the FusedMoE constructor.
pcp_size_ (int): `pcp_size` passed into the FusedMoE constructor.
dp_size_ (int): `dp_size` passed into the FusedMoE constructor.
vllm_parallel_config (ParallelConfig): vLLM's parallel config
object which contains the `enable_expert_parallel` flag.
Examples:
When there is no parallelism requested,
i.e. `tp_size_` = `dp_size_` = 1, we simply return the sizes
i.e. `tp_size_` = `pcp_size_` = `dp_size_` = 1, we simply return the sizes
unaltered and the ranks set to 0.
Expert Parallelism is considered only when either `dp_size_` or
Expert Parallelism is considered only when either `dp_size_`, `pcp_size_` or
`tp_size_` is non trivial.
When TP = 2, DP = 1 and EP = False, the configuration on different
Note that PCP serves the same function as DP here.
When TP = 2, DP(PCP) = 1 and EP = False, the configuration on different
devices:
- device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} //
@ -754,7 +766,7 @@ class FusedMoEParallelConfig:
- device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0}
- Comment : Tensors are sharded across 2 devices.
When TP = 1, DP = 2 and EP = False, the configuration on different
When TP = 1, DP(PCP) = 2 and EP = False, the configuration on different
devices:
- device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0}
@ -762,7 +774,7 @@ class FusedMoEParallelConfig:
- Comment: There are 2 engine instances and the tensors are sharded
across 2 decvices.
When TP = 2, DP = 2 and EP = False, the configuration on different
When TP = 2, DP(PCP) = 2 and EP = False, the configuration on different
devices:
- device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0}
@ -772,14 +784,14 @@ class FusedMoEParallelConfig:
- Comment: There are 2 engine instances and the tensors are sharded
across 4 devices.
When, TP = 2, DP = 1 and EP = True, the configuration on different
When, TP = 2, DP(PCP) = 1 and EP = True, the configuration on different
devices:
- device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0}
- device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1}
- Comment: The experts are split between the 2 devices.
When, TP = 1, DP = 2 and EP = True, the configuration on different
When, TP = 1, DP(PCP) = 2 and EP = True, the configuration on different
devices:
- device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0}
@ -787,7 +799,7 @@ class FusedMoEParallelConfig:
- Comment: There are 2 engine instances and the experts are split
between the 2 devices.
When TP = 2, DP = 2 and EP = True, the configuration on different
When TP = 2, DP(PCP) = 2 and EP = True, the configuration on different
devices:
- device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0}
@ -798,18 +810,25 @@ class FusedMoEParallelConfig:
between the 4 devices.
"""
use_ep = dp_size_ * tp_size_ > 1 and vllm_parallel_config.enable_expert_parallel
use_ep = (
dp_size_ * pcp_size_ * tp_size_ > 1
and vllm_parallel_config.enable_expert_parallel
)
dp_size = dp_size_
dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
tp_size_, dp_size_, dp_rank
pcp_size = pcp_size_
pcp_rank = get_pcp_group().rank_in_group if pcp_size > 1 else 0
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
tp_size_, dp_size_, dp_rank, pcp_size_, pcp_rank
)
if not use_ep:
return FusedMoEParallelConfig(
tp_size=tp_size,
tp_rank=tp_rank,
pcp_size=pcp_size,
pcp_rank=pcp_rank,
dp_size=dp_size,
dp_rank=dp_rank,
ep_size=1,
@ -826,6 +845,8 @@ class FusedMoEParallelConfig:
return FusedMoEParallelConfig(
tp_size=1,
tp_rank=0,
pcp_size=pcp_size,
pcp_rank=pcp_rank,
dp_size=dp_size,
dp_rank=dp_rank,
ep_size=ep_size,

View File

@ -18,6 +18,7 @@ from vllm.config.parallel import ExpertPlacementStrategy
from vllm.distributed import (
get_dp_group,
get_ep_group,
get_pcp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
@ -343,6 +344,7 @@ class FusedMoE(CustomOp):
tp_size: int | None = None,
ep_size: int | None = None,
dp_size: int | None = None,
pcp_size: int | None = None,
prefix: str = "",
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
@ -398,12 +400,14 @@ class FusedMoE(CustomOp):
tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
)
dp_size_ = dp_size if dp_size is not None else get_dp_group().world_size
pcp_size_ = pcp_size if pcp_size is not None else get_pcp_group().world_size
self.is_sequence_parallel = is_sequence_parallel
self.sp_size = tp_size_ if is_sequence_parallel else 1
self.moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
tp_size_=tp_size_,
pcp_size_=pcp_size_,
dp_size_=dp_size_,
vllm_parallel_config=vllm_config.parallel_config,
)
@ -679,6 +683,10 @@ class FusedMoE(CustomOp):
def dp_size(self):
return self.moe_parallel_config.dp_size
@property
def pcp_size(self):
return self.moe_parallel_config.pcp_size
@property
def ep_size(self):
return self.moe_parallel_config.ep_size
@ -691,6 +699,10 @@ class FusedMoE(CustomOp):
def dp_rank(self):
return self.moe_parallel_config.dp_rank
@property
def pcp_rank(self):
return self.moe_parallel_config.pcp_rank
@property
def ep_rank(self):
return self.moe_parallel_config.ep_rank
@ -1871,6 +1883,19 @@ class FusedMoE(CustomOp):
assert self.shared_experts is not None
shared_output = self.shared_experts(hidden_states)
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
# we should modify All2AllManager abstract to better support PCP.
if self.pcp_size > 1:
hidden_states = get_pcp_group().all_gather(
hidden_states,
dim=0,
)
router_logits = get_pcp_group().all_gather(
router_logits,
dim=0,
)
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
@ -1925,6 +1950,13 @@ class FusedMoE(CustomOp):
def combine_output(states: torch.Tensor) -> torch.Tensor:
if do_naive_dispatch_combine:
states = get_ep_group().combine(states, self.is_sequence_parallel)
if self.pcp_size > 1:
states = get_pcp_group().reduce_scatter(
states,
dim=0,
)
return states
if self.shared_experts is not None:

View File

@ -13,6 +13,7 @@ from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
get_dp_group,
get_ep_group,
get_pcp_group,
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@ -322,10 +323,12 @@ class GptOssModel(nn.Module):
# In MoE, we need to flatten the tensor parallel size across the data
# parallel size when EP is disabled.
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
tp_size=get_tensor_model_parallel_world_size(),
dp_size=get_dp_group().world_size,
dp_rank=get_dp_group().rank_in_group,
pcp_size=get_pcp_group().world_size,
pcp_rank=get_pcp_group().rank_in_group,
)
intermediate_size = self.config.intermediate_size
@ -507,10 +510,12 @@ class GptOssModel(nn.Module):
# In MoE, we need to flatten the tensor parallel size across the data
# parallel size when EP is disabled.
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
tp_size=get_tensor_model_parallel_world_size(),
dp_size=get_dp_group().world_size,
dp_rank=get_dp_group().rank_in_group,
pcp_size=get_pcp_group().world_size,
pcp_rank=get_pcp_group().rank_in_group,
)
intermediate_size = self.config.intermediate_size

View File

@ -265,8 +265,8 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
self.dcp_world_size = 1
self.dcp_rank = 0
self.dcp_kv_cache_interleave_size = (
self.parallel_config.dcp_kv_cache_interleave_size
self.cp_kv_cache_interleave_size = (
self.parallel_config.cp_kv_cache_interleave_size
)
self.use_full_cuda_graph = (
@ -388,7 +388,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
dcp_context_kv_lens_cpu,
self.dcp_world_size,
self.dcp_rank,
self.dcp_kv_cache_interleave_size,
self.cp_kv_cache_interleave_size,
)
dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
max_dcp_context_kv_len = dcp_context_kv_lens.max().item()

View File

@ -536,7 +536,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
self.dcp_local_block_size = parallel_config.dcp_kv_cache_interleave_size
self.dcp_local_block_size = parallel_config.cp_kv_cache_interleave_size
self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size
# Don't try to access the runner on AMD
@ -1289,8 +1289,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
get_current_vllm_config()
)
)
self.dcp_kv_cache_interleave_size: int = (
get_current_vllm_config().parallel_config.dcp_kv_cache_interleave_size
self.cp_kv_cache_interleave_size: int = (
get_current_vllm_config().parallel_config.cp_kv_cache_interleave_size
)
def _flash_attn_varlen_diff_headdims(

View File

@ -1080,9 +1080,9 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
def get_dcp_local_seq_lens(
seq_lens: torch.Tensor,
dcp_world_size: int = 1,
dcp_size: int = 1,
dcp_rank: int | None = None,
dcp_kv_cache_interleave_size: int = 1,
cp_kv_cache_interleave_size: int = 1,
) -> torch.Tensor:
"""While using dcp, kv_cache size stored on each rank may be different,
use this function to calculate split decode seq_lens of each dcp rank.
@ -1091,7 +1091,7 @@ def get_dcp_local_seq_lens(
num_requests = seq_lens.size(0)
if dcp_rank is None:
rank_offsets = (
torch.arange(dcp_world_size, dtype=torch.int32)
torch.arange(dcp_size, dtype=torch.int32)
.unsqueeze(0)
.repeat(num_requests, 1)
)
@ -1102,15 +1102,15 @@ def get_dcp_local_seq_lens(
)
base = (
seq_lens_tiled
// dcp_kv_cache_interleave_size
// dcp_world_size
* dcp_kv_cache_interleave_size
// cp_kv_cache_interleave_size
// dcp_size
* cp_kv_cache_interleave_size
)
remainder = seq_lens_tiled - base * dcp_world_size
remainder = seq_lens_tiled - base * dcp_size
remainder = torch.clip(
remainder - rank_offsets * dcp_kv_cache_interleave_size,
remainder - rank_offsets * cp_kv_cache_interleave_size,
0,
dcp_kv_cache_interleave_size,
cp_kv_cache_interleave_size,
)
dcp_local_seq_lens = base + remainder
return dcp_local_seq_lens.squeeze(1)

View File

@ -27,6 +27,7 @@ class KVCacheCoordinator(ABC):
enable_caching: bool,
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
):
self.kv_cache_config = kv_cache_config
self.max_model_len = max_model_len
@ -44,6 +45,7 @@ class KVCacheCoordinator(ABC):
block_pool=self.block_pool,
kv_cache_group_id=i,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
)
for i, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups)
)
@ -210,6 +212,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
use_eagle: bool,
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
):
super().__init__(
kv_cache_config,
@ -218,6 +221,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
False,
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
)
self.num_single_type_manager = len(self.single_type_managers)
@ -250,6 +254,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
enable_caching: bool,
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
):
super().__init__(
kv_cache_config,
@ -258,12 +263,16 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
enable_caching,
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
)
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec
self.block_size = self.kv_cache_spec.block_size
self.dcp_world_size = dcp_world_size
self.pcp_world_size = pcp_world_size
if dcp_world_size > 1:
self.block_size *= dcp_world_size
if pcp_world_size > 1:
self.block_size *= pcp_world_size
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
"UnitaryKVCacheCoordinator assumes only one kv cache group"
)
@ -281,6 +290,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
kv_cache_spec=self.kv_cache_spec,
use_eagle=self.use_eagle,
dcp_world_size=self.dcp_world_size,
pcp_world_size=self.pcp_world_size,
)
return hit_blocks, len(hit_blocks[0]) * self.block_size
@ -302,6 +312,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
enable_caching: bool,
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
):
super().__init__(
kv_cache_config,
@ -310,8 +321,10 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
enable_caching,
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
)
assert dcp_world_size == 1, "DCP not support hybrid attn now."
assert pcp_world_size == 1, "PCP not support hybrid attn now."
self.verify_and_split_kv_cache_groups()
def verify_and_split_kv_cache_groups(self) -> None:
@ -452,6 +465,7 @@ def get_kv_cache_coordinator(
enable_caching: bool,
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
) -> KVCacheCoordinator:
if not enable_caching:
return KVCacheCoordinatorNoPrefixCache(
@ -460,6 +474,7 @@ def get_kv_cache_coordinator(
use_eagle,
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
)
if len(kv_cache_config.kv_cache_groups) == 1:
return UnitaryKVCacheCoordinator(
@ -469,6 +484,7 @@ def get_kv_cache_coordinator(
enable_caching,
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
)
return HybridKVCacheCoordinator(
kv_cache_config,
@ -477,4 +493,5 @@ def get_kv_cache_coordinator(
enable_caching,
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
)

View File

@ -100,6 +100,7 @@ class KVCacheManager:
log_stats: bool = False,
enable_kv_cache_events: bool = False,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> None:
self.max_model_len = max_model_len
@ -124,12 +125,9 @@ class KVCacheManager:
0
].kv_cache_spec.block_size
if dcp_world_size > 1:
if dcp_world_size * pcp_world_size > 1:
assert len(kv_cache_config.kv_cache_groups) == 1
# Note(hc): need revisit. When both DCP and any future
# PCP are enabled, the block_size may need to be scaled
# by a factor of dcp_size × pcp_size?
self.block_size *= dcp_world_size
self.block_size *= dcp_world_size * pcp_world_size
self.coordinator = get_kv_cache_coordinator(
kv_cache_config=kv_cache_config,
@ -138,6 +136,7 @@ class KVCacheManager:
enable_caching=self.enable_caching,
enable_kv_cache_events=enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
)
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
self.block_pool = self.coordinator.block_pool

View File

@ -1219,11 +1219,16 @@ def _report_kv_cache_config(
// len(kv_cache_config.kv_cache_groups)
* min_block_size
)
if vllm_config.parallel_config.decode_context_parallel_size > 1:
num_tokens *= vllm_config.parallel_config.decode_context_parallel_size
dcp_size = vllm_config.parallel_config.decode_context_parallel_size
pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
if pcp_size * dcp_size > 1:
num_tokens *= pcp_size * dcp_size
logger.info(
"Multiplying the GPU KV cache size by the dcp_world_size %d.",
vllm_config.parallel_config.decode_context_parallel_size,
"Multiplying the GPU KV cache size by the cp_world_size %d "
"(pcp_world_size %d * dcp_world_size %d).",
pcp_size * dcp_size,
pcp_size,
dcp_size,
)
num_tokens_str = f"{num_tokens:,}"
logger.info_once("GPU KV cache size: %s tokens", num_tokens_str, scope="local")

View File

@ -121,6 +121,7 @@ class Scheduler(SchedulerInterface):
self.block_size = block_size
self.dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
self.pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size
# req_id -> Request
self.requests: dict[str, Request] = {}
@ -183,6 +184,7 @@ class Scheduler(SchedulerInterface):
log_stats=self.log_stats,
enable_kv_cache_events=self.enable_kv_cache_events,
dcp_world_size=self.dcp_world_size,
pcp_world_size=self.pcp_world_size,
)
self.use_pp = self.parallel_config.pipeline_parallel_size > 1

View File

@ -32,6 +32,7 @@ class SingleTypeKVCacheManager(ABC):
block_pool: BlockPool,
kv_cache_group_id: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> None:
"""
Initializes the SingleTypeKVCacheManager.
@ -42,8 +43,9 @@ class SingleTypeKVCacheManager(ABC):
"""
self.block_size = kv_cache_spec.block_size
self.dcp_world_size = dcp_world_size
if self.dcp_world_size > 1:
self.block_size *= dcp_world_size
self.pcp_world_size = pcp_world_size
if dcp_world_size * pcp_world_size > 1:
self.block_size *= dcp_world_size * pcp_world_size
self.kv_cache_spec = kv_cache_spec
self.block_pool = block_pool
@ -212,6 +214,7 @@ class SingleTypeKVCacheManager(ABC):
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
"""
Get the longest cache hit prefix of the blocks that is not longer than
@ -303,6 +306,7 @@ class FullAttentionManager(SingleTypeKVCacheManager):
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(
kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec)
@ -314,8 +318,8 @@ class FullAttentionManager(SingleTypeKVCacheManager):
[] for _ in range(len(kv_cache_group_ids))
)
block_size = kv_cache_spec.block_size
if dcp_world_size > 1:
block_size *= dcp_world_size
if dcp_world_size * pcp_world_size > 1:
block_size *= dcp_world_size * pcp_world_size
max_num_blocks = max_length // block_size
for block_hash in itertools.islice(block_hashes, max_num_blocks):
# block_hashes is a chain of block hashes. If a block hash is not
@ -362,11 +366,13 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(kv_cache_spec, SlidingWindowSpec), (
"SlidingWindowManager can only be used for sliding window groups"
)
assert dcp_world_size == 1, "DCP not support sliding window attn now."
assert pcp_world_size == 1, "PCP not support sliding window attn now."
# The number of contiguous blocks needed for prefix cache hit.
# -1 since the input token itself is also included in the window
@ -476,6 +482,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
"""
For chunked local attention, we need to find the longest cache hit
@ -516,6 +523,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
"Hybrid KV cache is not supported for " + "eagle + chunked local attention."
)
assert dcp_world_size == 1, "DCP not support chunked local attn now."
assert pcp_world_size == 1, "PCP not support chunked local attn now."
max_num_blocks = max_length // kv_cache_spec.block_size
if max_length > 0:
local_attention_start_idx = (
@ -611,11 +619,13 @@ class MambaManager(SingleTypeKVCacheManager):
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(kv_cache_spec, MambaSpec), (
"MambaManager can only be used for mamba groups"
)
assert dcp_world_size == 1, "DCP not support mamba now."
assert pcp_world_size == 1, "PCP not support mamba now."
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
[] for _ in range(len(kv_cache_group_ids))
)
@ -705,6 +715,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(kv_cache_spec, CrossAttentionSpec), (
"CrossAttentionManager can only be used for cross-attention groups"

View File

@ -128,6 +128,7 @@ class EngineCore:
scheduler_block_size = (
vllm_config.cache_config.block_size
* vllm_config.parallel_config.decode_context_parallel_size
* vllm_config.parallel_config.prefill_context_parallel_size
)
self.scheduler: SchedulerInterface = Scheduler(

View File

@ -35,6 +35,7 @@ from vllm.distributed.parallel_state import (
get_dp_group,
get_ep_group,
get_inner_dp_world_group,
get_pcp_group,
get_pp_group,
get_tp_group,
)
@ -110,12 +111,14 @@ class MultiprocExecutor(Executor):
f"({self.parallel_config.nnodes_within_dp}). "
)
self.local_world_size = self.parallel_config.local_world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size
pp_parallel_size = self.parallel_config.pipeline_parallel_size
assert self.world_size == tensor_parallel_size * pp_parallel_size, (
tp_size = self.parallel_config.tensor_parallel_size
pp_size = self.parallel_config.pipeline_parallel_size
pcp_size = self.parallel_config.prefill_context_parallel_size
assert self.world_size == tp_size * pp_size * pcp_size, (
f"world_size ({self.world_size}) must be equal to the "
f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
f"_parallel_size ({pp_parallel_size}). "
f"tensor_parallel_size ({tp_size}) x pipeline"
f"_parallel_size ({pp_size}) x prefill_context"
f"_parallel_size ({pcp_size}). "
)
# Set multiprocessing envs
@ -424,7 +427,11 @@ class MultiprocExecutor(Executor):
# 16-23, PP rank 2
# 24-31, PP rank 3
# so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3)
return self.world_size - self.parallel_config.tensor_parallel_size
return (
self.world_size
- self.parallel_config.tensor_parallel_size
* self.parallel_config.prefill_context_parallel_size
)
@dataclass
@ -828,6 +835,8 @@ class WorkerProc:
dp_rank = get_dp_group().rank_in_group
pp_size = get_pp_group().world_size
pp_rank = get_pp_group().rank_in_group
pcp_size = get_pcp_group().world_size
pcp_rank = get_pcp_group().rank_in_group
tp_size = get_tp_group().world_size
tp_rank = get_tp_group().rank_in_group
dcp_size = get_dcp_group().world_size
@ -837,6 +846,8 @@ class WorkerProc:
process_name += f"_DP{dp_rank}"
if pp_size > 1:
process_name += f"_PP{pp_rank}"
if pcp_size > 1:
process_name += f"_PCP{pcp_rank}"
if tp_size > 1:
process_name += f"_TP{tp_rank}"
if dcp_size > 1:

View File

@ -95,10 +95,11 @@ class FullAttentionSpec(AttentionSpec):
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_model_len = vllm_config.model_config.max_model_len
dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size
# Note(hc): each dcp rank only need save
# (max_model_len//dcp_world_size) tokens locally.
if dcp_world_size > 1:
max_model_len = cdiv(max_model_len, dcp_world_size)
if dcp_world_size * pcp_world_size > 1:
max_model_len = cdiv(max_model_len, dcp_world_size * pcp_world_size)
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
@classmethod

View File

@ -4,7 +4,7 @@
import numpy as np
import torch
from vllm.distributed import get_dcp_group
from vllm.distributed import get_dcp_group, get_pcp_group
from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
@ -22,7 +22,7 @@ class BlockTable:
pin_memory: bool,
device: torch.device,
kernel_block_size: int,
dcp_kv_cache_interleave_size: int,
cp_kv_cache_interleave_size: int,
):
"""
Args:
@ -80,6 +80,13 @@ class BlockTable:
else:
self._kernel_block_arange = None
try:
self.pcp_world_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group().rank_in_group
except AssertionError:
# DCP might not be initialized in testing
self.pcp_world_size = 1
self.pcp_rank = 0
try:
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
@ -87,7 +94,7 @@ class BlockTable:
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
self.dcp_kv_cache_interleave_size = dcp_kv_cache_interleave_size
self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
def append_row(
self,
@ -131,14 +138,16 @@ class BlockTable:
# NOTE(woosuk): We can't simply use `token_indices // block_size`
# here because M (max_model_len) is not necessarily divisible by
# block_size.
if self.dcp_world_size > 1:
total_cp_world_size = self.pcp_world_size * self.dcp_world_size
total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
if total_cp_world_size > 1:
# Note(hc): The DCP implement store kvcache with an interleave
# style, the kvcache for the token whose token_idx is i is
# always stored on the GPU whose dcp_rank equals i % cp_world_size:
# Use a "virtual block" which equals to world_size * block_size
# for block_table_indices calculation.
virtual_block_size = self.block_size * self.dcp_world_size
virtual_block_size = self.block_size * total_cp_world_size
block_table_indices = (
req_indices * self.max_num_blocks_per_req
+ positions // virtual_block_size
@ -150,16 +159,16 @@ class BlockTable:
virtual_block_offsets = positions % virtual_block_size
mask = (
virtual_block_offsets
// self.dcp_kv_cache_interleave_size
% self.dcp_world_size
== self.dcp_rank
// self.cp_kv_cache_interleave_size
% total_cp_world_size
== total_cp_rank
)
# Calculate local block_offsets
block_offsets = (
virtual_block_offsets
// (self.dcp_world_size * self.dcp_kv_cache_interleave_size)
* self.dcp_kv_cache_interleave_size
+ virtual_block_offsets % self.dcp_kv_cache_interleave_size
// (total_cp_world_size * self.cp_kv_cache_interleave_size)
* self.cp_kv_cache_interleave_size
+ virtual_block_offsets % self.cp_kv_cache_interleave_size
)
# Calculate slot_mapping
slot_mapping = block_numbers * self.block_size + block_offsets
@ -253,7 +262,7 @@ class MultiGroupBlockTable:
block_sizes: list[int],
kernel_block_sizes: list[int],
num_speculative_tokens: int = 0,
dcp_kv_cache_interleave_size: int = 1,
cp_kv_cache_interleave_size: int = 1,
) -> None:
# Note(hc): each dcp rank only store
# (max_model_len//dcp_world_size) tokens in kvcache,
@ -283,7 +292,7 @@ class MultiGroupBlockTable:
pin_memory,
device,
kernel_block_size,
dcp_kv_cache_interleave_size,
cp_kv_cache_interleave_size,
)
for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
]

View File

@ -87,7 +87,7 @@ class InputBatch:
is_spec_decode: bool = False,
is_pooling_model: bool = False,
num_speculative_tokens: int = 0,
dcp_kv_cache_interleave_size: int = 1,
cp_kv_cache_interleave_size: int = 1,
):
self.is_pooling_model = is_pooling_model
self.is_spec_decode = is_spec_decode
@ -141,7 +141,7 @@ class InputBatch:
block_sizes=block_sizes,
kernel_block_sizes=kernel_block_sizes,
num_speculative_tokens=num_speculative_tokens,
dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
)
# Sampling-related.

View File

@ -426,7 +426,7 @@ class GPUModelRunner(
# uses output token ids so we set this conservatively.
logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
is_pooling_model=self.is_pooling_model,
dcp_kv_cache_interleave_size=self.parallel_config.dcp_kv_cache_interleave_size,
cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size,
)
self.use_async_scheduling = self.scheduler_config.async_scheduling
@ -1436,7 +1436,7 @@ class GPUModelRunner(
self.seq_lens.cpu[:num_reqs],
self.dcp_world_size,
self.dcp_rank,
self.parallel_config.dcp_kv_cache_interleave_size,
self.parallel_config.cp_kv_cache_interleave_size,
)
self.dcp_local_seq_lens.copy_to_gpu(num_reqs)

View File

@ -26,6 +26,7 @@ from vllm.distributed.kv_transfer import (
has_kv_transfer_group,
)
from vllm.distributed.parallel_state import (
get_pcp_group,
get_pp_group,
get_tp_group,
)
@ -733,6 +734,7 @@ class Worker(WorkerBase):
module.global_num_experts = module.moe_config.num_experts
module.moe_parallel_config = FusedMoEParallelConfig.make(
tp_size_=get_tp_group().world_size,
pcp_size_=get_pcp_group().world_size,
dp_size_=get_dp_group().world_size,
vllm_parallel_config=parallel_config,
)
@ -886,6 +888,7 @@ def init_worker_distributed_environment(
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
parallel_config.prefill_context_parallel_size,
parallel_config.decode_context_parallel_size,
)