mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-16 01:49:07 +08:00
[Feature] Prefill Context Parallel (PCP) basic support (#28718)
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com> Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com> Signed-off-by: LookAround <lixushi@huawei.com> Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com> Signed-off-by: zhenwenqi2024 <zhenwenqi_2022@qq.com> Co-authored-by: FENP <yuanyongjie.yyj@antgroup.com> Co-authored-by: LookAround <lixushi@huawei.com> Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com> Co-authored-by: zhenwenqi2024 <zhenwenqi_2022@qq.com> Co-authored-by: Jingchun Gao <63247409+gjc0824@users.noreply.github.com>
This commit is contained in:
parent
02f5903b84
commit
2fd893b4ce
@ -31,7 +31,7 @@ class ParallelSetup(NamedTuple):
|
|||||||
tp_size: int
|
tp_size: int
|
||||||
pp_size: int
|
pp_size: int
|
||||||
dcp_size: int
|
dcp_size: int
|
||||||
dcp_kv_cache_interleave_size: int
|
cp_kv_cache_interleave_size: int
|
||||||
eager_mode: bool
|
eager_mode: bool
|
||||||
chunked_prefill: bool
|
chunked_prefill: bool
|
||||||
|
|
||||||
@ -55,7 +55,7 @@ class CPTestSettings:
|
|||||||
tp_base: int = 4,
|
tp_base: int = 4,
|
||||||
pp_base: int = 1,
|
pp_base: int = 1,
|
||||||
dcp_base: int = 1,
|
dcp_base: int = 1,
|
||||||
dcp_kv_cache_interleave_size: int = 1,
|
cp_kv_cache_interleave_size: int = 1,
|
||||||
multi_node_only: bool = False,
|
multi_node_only: bool = False,
|
||||||
runner: RunnerOption = "auto",
|
runner: RunnerOption = "auto",
|
||||||
load_format: str | None = None,
|
load_format: str | None = None,
|
||||||
@ -71,7 +71,7 @@ class CPTestSettings:
|
|||||||
tp_size=tp_base,
|
tp_size=tp_base,
|
||||||
pp_size=pp_multiplier * pp_base,
|
pp_size=pp_multiplier * pp_base,
|
||||||
dcp_size=int(dcp_multiplier * tp_base),
|
dcp_size=int(dcp_multiplier * tp_base),
|
||||||
dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
|
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
|
||||||
eager_mode=eager_mode_val,
|
eager_mode=eager_mode_val,
|
||||||
chunked_prefill=chunked_prefill_val,
|
chunked_prefill=chunked_prefill_val,
|
||||||
)
|
)
|
||||||
@ -116,7 +116,7 @@ def _compare_cp_with_tp(
|
|||||||
tp_size,
|
tp_size,
|
||||||
pp_size,
|
pp_size,
|
||||||
dcp_size,
|
dcp_size,
|
||||||
dcp_kv_cache_interleave_size,
|
cp_kv_cache_interleave_size,
|
||||||
eager_mode,
|
eager_mode,
|
||||||
chunked_prefill,
|
chunked_prefill,
|
||||||
) = parallel_setup
|
) = parallel_setup
|
||||||
@ -197,7 +197,7 @@ def _compare_cp_with_tp(
|
|||||||
"--decode-context-parallel-size",
|
"--decode-context-parallel-size",
|
||||||
str(dcp_size),
|
str(dcp_size),
|
||||||
"--dcp-kv-cache-interleave-size",
|
"--dcp-kv-cache-interleave-size",
|
||||||
str(dcp_kv_cache_interleave_size),
|
str(cp_kv_cache_interleave_size),
|
||||||
"--distributed-executor-backend",
|
"--distributed-executor-backend",
|
||||||
distributed_backend,
|
distributed_backend,
|
||||||
]
|
]
|
||||||
@ -227,7 +227,7 @@ CP_TEXT_GENERATION_MODELS = {
|
|||||||
"deepseek-ai/DeepSeek-V2-Lite-Chat": [
|
"deepseek-ai/DeepSeek-V2-Lite-Chat": [
|
||||||
CPTestSettings.detailed(),
|
CPTestSettings.detailed(),
|
||||||
CPTestSettings.detailed(tp_base=2),
|
CPTestSettings.detailed(tp_base=2),
|
||||||
CPTestSettings.detailed(tp_base=2, dcp_kv_cache_interleave_size=64),
|
CPTestSettings.detailed(tp_base=2, cp_kv_cache_interleave_size=64),
|
||||||
],
|
],
|
||||||
"bigcode/gpt_bigcode-santacoder": [
|
"bigcode/gpt_bigcode-santacoder": [
|
||||||
CPTestSettings.detailed(),
|
CPTestSettings.detailed(),
|
||||||
|
|||||||
@ -15,7 +15,11 @@ from tests.kernels.quantization.nvfp4_utils import (
|
|||||||
)
|
)
|
||||||
from tests.kernels.utils import torch_experts
|
from tests.kernels.utils import torch_experts
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
|
from vllm.distributed import (
|
||||||
|
get_dp_group,
|
||||||
|
get_pcp_group,
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
)
|
||||||
from vllm.forward_context import set_forward_context
|
from vllm.forward_context import set_forward_context
|
||||||
from vllm.model_executor.layers.fused_moe.config import (
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
FusedMoEConfig,
|
FusedMoEConfig,
|
||||||
@ -561,6 +565,7 @@ def make_modular_kernel(
|
|||||||
# make moe config
|
# make moe config
|
||||||
moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
|
moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
|
||||||
tp_size_=get_tensor_model_parallel_world_size(),
|
tp_size_=get_tensor_model_parallel_world_size(),
|
||||||
|
pcp_size_=get_pcp_group().world_size,
|
||||||
dp_size_=get_dp_group().world_size,
|
dp_size_=get_dp_group().world_size,
|
||||||
vllm_parallel_config=vllm_config.parallel_config,
|
vllm_parallel_config=vllm_config.parallel_config,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -956,7 +956,7 @@ def test_hybrid_block_table_initialization():
|
|||||||
max_num_reqs = 10
|
max_num_reqs = 10
|
||||||
max_num_blocks_per_req = 20
|
max_num_blocks_per_req = 20
|
||||||
max_num_batched_tokens = 512
|
max_num_batched_tokens = 512
|
||||||
dcp_kv_cache_interleave_size = 8
|
cp_kv_cache_interleave_size = 8
|
||||||
|
|
||||||
block_table = BlockTable(
|
block_table = BlockTable(
|
||||||
block_size=block_size,
|
block_size=block_size,
|
||||||
@ -966,7 +966,7 @@ def test_hybrid_block_table_initialization():
|
|||||||
pin_memory=False,
|
pin_memory=False,
|
||||||
device=torch.device(DEVICE),
|
device=torch.device(DEVICE),
|
||||||
kernel_block_size=kernel_block_sizes[0],
|
kernel_block_size=kernel_block_sizes[0],
|
||||||
dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
|
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify hybrid block configuration
|
# Verify hybrid block configuration
|
||||||
|
|||||||
@ -266,6 +266,12 @@ class AttentionImpl(ABC, Generic[T]):
|
|||||||
dcp_world_size: int
|
dcp_world_size: int
|
||||||
dcp_rank: int
|
dcp_rank: int
|
||||||
|
|
||||||
|
pcp_world_size: int
|
||||||
|
pcp_rank: int
|
||||||
|
|
||||||
|
total_cp_world_size: int
|
||||||
|
total_cp_rank: int
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
def __new__(cls, *args, **kwargs):
|
||||||
# use __new__ so that all subclasses will call this
|
# use __new__ so that all subclasses will call this
|
||||||
self = super().__new__(cls)
|
self = super().__new__(cls)
|
||||||
@ -278,6 +284,17 @@ class AttentionImpl(ABC, Generic[T]):
|
|||||||
# DCP might not be initialized in testing
|
# DCP might not be initialized in testing
|
||||||
self.dcp_world_size = 1
|
self.dcp_world_size = 1
|
||||||
self.dcp_rank = 0
|
self.dcp_rank = 0
|
||||||
|
try:
|
||||||
|
from vllm.distributed.parallel_state import get_pcp_group
|
||||||
|
|
||||||
|
self.pcp_world_size = get_pcp_group().world_size
|
||||||
|
self.pcp_rank = get_pcp_group().rank_in_group
|
||||||
|
except AssertionError:
|
||||||
|
self.pcp_world_size = 1
|
||||||
|
self.pcp_rank = 0
|
||||||
|
self.total_cp_world_size = self.pcp_world_size * self.dcp_world_size
|
||||||
|
self.total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
|
||||||
|
|
||||||
self.need_to_return_lse_for_decode = (
|
self.need_to_return_lse_for_decode = (
|
||||||
self.dcp_world_size > 1 and self.can_return_lse_for_decode
|
self.dcp_world_size > 1 and self.can_return_lse_for_decode
|
||||||
)
|
)
|
||||||
|
|||||||
@ -169,12 +169,11 @@ def correct_attn_out(
|
|||||||
return out, lse
|
return out, lse
|
||||||
|
|
||||||
|
|
||||||
def cp_lse_ag_out_rs(
|
def _cp_lse_common(
|
||||||
cp_attn_out: torch.Tensor,
|
cp_attn_out: torch.Tensor,
|
||||||
cp_attn_lse: torch.Tensor,
|
cp_attn_lse: torch.Tensor,
|
||||||
cp_group: GroupCoordinator,
|
cp_group: GroupCoordinator,
|
||||||
ctx: CPTritonContext = None,
|
ctx: CPTritonContext | None = None,
|
||||||
return_lse=False,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
cp_attn_out: [ B, H, D ]
|
cp_attn_out: [ B, H, D ]
|
||||||
@ -195,6 +194,22 @@ def cp_lse_ag_out_rs(
|
|||||||
cp_attn_lse = cp_attn_lse.contiguous()
|
cp_attn_lse = cp_attn_lse.contiguous()
|
||||||
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
|
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
|
||||||
out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
|
out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
|
||||||
|
assert out.is_contiguous()
|
||||||
|
return out, lse
|
||||||
|
|
||||||
|
|
||||||
|
def cp_lse_ag_out_rs(
|
||||||
|
cp_attn_out: torch.Tensor,
|
||||||
|
cp_attn_lse: torch.Tensor,
|
||||||
|
cp_group: GroupCoordinator,
|
||||||
|
ctx: CPTritonContext | None = None,
|
||||||
|
return_lse: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
cp_attn_out: [ B, H, D ]
|
||||||
|
cp_attn_lse: [ B, H ]
|
||||||
|
"""
|
||||||
|
out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
|
||||||
out = cp_group.reduce_scatter(out, dim=1)
|
out = cp_group.reduce_scatter(out, dim=1)
|
||||||
|
|
||||||
if return_lse:
|
if return_lse:
|
||||||
@ -205,6 +220,25 @@ def cp_lse_ag_out_rs(
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def cp_lse_ag_out_ar(
|
||||||
|
cp_attn_out: torch.Tensor,
|
||||||
|
cp_attn_lse: torch.Tensor,
|
||||||
|
cp_group: GroupCoordinator,
|
||||||
|
ctx: CPTritonContext | None = None,
|
||||||
|
return_lse: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
cp_attn_out: [ B, H, D ]
|
||||||
|
cp_attn_lse: [ B, H ]
|
||||||
|
"""
|
||||||
|
out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
|
||||||
|
out = cp_group.all_reduce(out)
|
||||||
|
|
||||||
|
if return_lse:
|
||||||
|
return out, lse
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _pack_seq_kernel(
|
def _pack_seq_kernel(
|
||||||
x_ptr, # [N, D]
|
x_ptr, # [N, D]
|
||||||
|
|||||||
@ -71,6 +71,8 @@ class ParallelConfig:
|
|||||||
"""Number of pipeline parallel groups."""
|
"""Number of pipeline parallel groups."""
|
||||||
tensor_parallel_size: int = 1
|
tensor_parallel_size: int = 1
|
||||||
"""Number of tensor parallel groups."""
|
"""Number of tensor parallel groups."""
|
||||||
|
prefill_context_parallel_size: int = 1
|
||||||
|
"""Number of prefill context parallel groups."""
|
||||||
data_parallel_size: int = 1
|
data_parallel_size: int = 1
|
||||||
"""Number of data parallel groups. MoE layers will be sharded according to
|
"""Number of data parallel groups. MoE layers will be sharded according to
|
||||||
the product of the tensor parallel size and data parallel size."""
|
the product of the tensor parallel size and data parallel size."""
|
||||||
@ -239,14 +241,25 @@ class ParallelConfig:
|
|||||||
needs to be divisible by dcp_size."""
|
needs to be divisible by dcp_size."""
|
||||||
|
|
||||||
dcp_kv_cache_interleave_size: int = 1
|
dcp_kv_cache_interleave_size: int = 1
|
||||||
"""Interleave size of kv_cache storage while using dcp or cp > 1,
|
"""
|
||||||
store interleave_size tokens on (d)cp i,
|
Interleave size of kv_cache storage while using DCP.
|
||||||
then store next interleave_size tokens on (d)cp i+1.
|
dcp_kv_cache_interleave_size has been replaced by cp_kv_cache_interleave_size,
|
||||||
Interleave_size=1: token-level align, token i is stored on rank i % (d)cp_size.
|
and will be deprecated when PCP is fully supported.
|
||||||
Interleave_size=block_size: block-level align, first fill the block on first rank,
|
|
||||||
token is stored on rank i+1 block j after rank i block j is full.
|
"""
|
||||||
Block_size should be greater than or equal to dcp_kv_cache_interleave_size.
|
cp_kv_cache_interleave_size: int = 1
|
||||||
Block_size should be divisible by dcp_kv_cache_interleave_size.
|
"""Interleave size of kv_cache storage while using DCP or PCP.
|
||||||
|
For `total_cp_rank = pcp_rank * dcp_world_size + dcp_rank`,
|
||||||
|
and `total_cp_world_size = pcp_world_size * dcp_world_szie`.
|
||||||
|
store interleave_size tokens on total_cp_rank i,
|
||||||
|
then store next interleave_size tokens on taotal_cp_rank i+1.
|
||||||
|
Interleave_size=1: token-level alignment, where token `i` is stored on
|
||||||
|
total_cp_rank `i % total_cp_world_size`.
|
||||||
|
Interleave_size=block_size: block-level alignment, where tokens are
|
||||||
|
first populated to the preceding ranks. Tokens are then stored
|
||||||
|
in (rank i+1, block j) only after (rank i, block j) is fully occupied.
|
||||||
|
Block_size should be greater than or equal to cp_kv_cache_interleave_size.
|
||||||
|
Block_size should be divisible by cp_kv_cache_interleave_size.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_api_process_count: int = Field(default=1, gt=0)
|
_api_process_count: int = Field(default=1, gt=0)
|
||||||
@ -311,6 +324,11 @@ class ParallelConfig:
|
|||||||
"num_redundant_experts."
|
"num_redundant_experts."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.prefill_context_parallel_size > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Prefill context parallelism is not fully supported. "
|
||||||
|
"Please set prefill_context_parallel_size to 1."
|
||||||
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -529,7 +547,11 @@ class ParallelConfig:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Continue with the rest of the initialization
|
# Continue with the rest of the initialization
|
||||||
self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size
|
self.world_size = (
|
||||||
|
self.pipeline_parallel_size
|
||||||
|
* self.tensor_parallel_size
|
||||||
|
* self.prefill_context_parallel_size
|
||||||
|
)
|
||||||
|
|
||||||
if self.distributed_executor_backend == "external_launcher":
|
if self.distributed_executor_backend == "external_launcher":
|
||||||
logger.info("Using external launcher for distributed inference.")
|
logger.info("Using external launcher for distributed inference.")
|
||||||
|
|||||||
@ -481,6 +481,14 @@ class VllmConfig:
|
|||||||
"Overriding cudagraph_mode to PIECEWISE."
|
"Overriding cudagraph_mode to PIECEWISE."
|
||||||
)
|
)
|
||||||
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
|
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
|
||||||
|
# prefill context parallel do not support full cudagraphs
|
||||||
|
elif self.parallel_config.prefill_context_parallel_size > 1:
|
||||||
|
logger.warning_once(
|
||||||
|
"Prefill context parallel (PCP) is enabled, which is "
|
||||||
|
"incompatible with full CUDA graphs. "
|
||||||
|
"Overriding cudagraph_mode to PIECEWISE."
|
||||||
|
)
|
||||||
|
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
|
||||||
elif self.model_config is not None:
|
elif self.model_config is not None:
|
||||||
if self.model_config.pooler_config is not None:
|
if self.model_config.pooler_config is not None:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
@ -610,22 +618,34 @@ class VllmConfig:
|
|||||||
|
|
||||||
# If DCP, ensure the block size is right.
|
# If DCP, ensure the block size is right.
|
||||||
if self.parallel_config.decode_context_parallel_size > 1:
|
if self.parallel_config.decode_context_parallel_size > 1:
|
||||||
|
if self.parallel_config.dcp_kv_cache_interleave_size > 1 and (
|
||||||
|
self.parallel_config.cp_kv_cache_interleave_size
|
||||||
|
!= self.parallel_config.dcp_kv_cache_interleave_size
|
||||||
|
):
|
||||||
|
self.parallel_config.cp_kv_cache_interleave_size = (
|
||||||
|
self.parallel_config.dcp_kv_cache_interleave_size
|
||||||
|
)
|
||||||
|
logger.warning_once(
|
||||||
|
"cp_kv_cache_interleave_size is overridden by dcp_kv_cache"
|
||||||
|
"_interleave_size. And dcp-kv-cache-interleave-size will be "
|
||||||
|
"deprecated when PCP is fully supported."
|
||||||
|
)
|
||||||
assert (
|
assert (
|
||||||
self.parallel_config.dcp_kv_cache_interleave_size
|
self.parallel_config.cp_kv_cache_interleave_size
|
||||||
<= self.cache_config.block_size
|
<= self.cache_config.block_size
|
||||||
and self.cache_config.block_size
|
and self.cache_config.block_size
|
||||||
% self.parallel_config.dcp_kv_cache_interleave_size
|
% self.parallel_config.cp_kv_cache_interleave_size
|
||||||
== 0
|
== 0
|
||||||
), (
|
), (
|
||||||
f"Block_size({self.cache_config.block_size}) should be greater "
|
f"Block_size({self.cache_config.block_size}) should be greater "
|
||||||
"than or equal to and divisible by dcp_kv_cache_interleave_size "
|
"than or equal to and divisible by cp_kv_cache_interleave_size "
|
||||||
f"({self.parallel_config.dcp_kv_cache_interleave_size})."
|
f"({self.parallel_config.cp_kv_cache_interleave_size})."
|
||||||
)
|
)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
self.parallel_config.dcp_kv_cache_interleave_size == 1
|
self.parallel_config.cp_kv_cache_interleave_size == 1
|
||||||
or self.speculative_config is None
|
or self.speculative_config is None
|
||||||
), "MTP with dcp_kv_cache_interleave_size > 1 is not supported now."
|
), "MTP with cp_kv_cache_interleave_size > 1 is not supported now."
|
||||||
|
|
||||||
# Do this after all the updates to compilation_config.mode
|
# Do this after all the updates to compilation_config.mode
|
||||||
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
|
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
|
||||||
|
|||||||
@ -1098,6 +1098,12 @@ get_context_model_parallel_group = get_dcp_group
|
|||||||
|
|
||||||
_PP: GroupCoordinator | None = None
|
_PP: GroupCoordinator | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_pp_group() -> GroupCoordinator:
|
||||||
|
assert _PP is not None, "pipeline model parallel group is not initialized"
|
||||||
|
return _PP
|
||||||
|
|
||||||
|
|
||||||
_DP: GroupCoordinator | None = None
|
_DP: GroupCoordinator | None = None
|
||||||
|
|
||||||
|
|
||||||
@ -1114,9 +1120,12 @@ def get_ep_group() -> GroupCoordinator:
|
|||||||
return _EP
|
return _EP
|
||||||
|
|
||||||
|
|
||||||
def get_pp_group() -> GroupCoordinator:
|
_PCP: GroupCoordinator | None = None
|
||||||
assert _PP is not None, "pipeline model parallel group is not initialized"
|
|
||||||
return _PP
|
|
||||||
|
def get_pcp_group() -> GroupCoordinator:
|
||||||
|
assert _PCP is not None, "prefill context parallel group is not initialized"
|
||||||
|
return _PCP
|
||||||
|
|
||||||
|
|
||||||
@deprecated(
|
@deprecated(
|
||||||
@ -1276,6 +1285,7 @@ def init_distributed_environment(
|
|||||||
def initialize_model_parallel(
|
def initialize_model_parallel(
|
||||||
tensor_model_parallel_size: int = 1,
|
tensor_model_parallel_size: int = 1,
|
||||||
pipeline_model_parallel_size: int = 1,
|
pipeline_model_parallel_size: int = 1,
|
||||||
|
prefill_context_model_parallel_size: int = 1,
|
||||||
decode_context_model_parallel_size: int | None = 1,
|
decode_context_model_parallel_size: int | None = 1,
|
||||||
backend: str | None = None,
|
backend: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -1325,7 +1335,11 @@ def initialize_model_parallel(
|
|||||||
# to get group_ranks for each dimension, transpose that dimension to the
|
# to get group_ranks for each dimension, transpose that dimension to the
|
||||||
# last dimension, then reshape to 2D, then unbind the last dimension
|
# last dimension, then reshape to 2D, then unbind the last dimension
|
||||||
all_ranks = torch.arange(world_size).reshape(
|
all_ranks = torch.arange(world_size).reshape(
|
||||||
-1, data_parallel_size, pipeline_model_parallel_size, tensor_model_parallel_size
|
-1,
|
||||||
|
data_parallel_size,
|
||||||
|
pipeline_model_parallel_size,
|
||||||
|
prefill_context_model_parallel_size,
|
||||||
|
tensor_model_parallel_size,
|
||||||
) # noqa
|
) # noqa
|
||||||
|
|
||||||
# Build the tensor model-parallel groups.
|
# Build the tensor model-parallel groups.
|
||||||
@ -1360,11 +1374,23 @@ def initialize_model_parallel(
|
|||||||
group_name="dcp",
|
group_name="dcp",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
global _PCP
|
||||||
|
assert _PCP is None, "prefill context parallel group is already initialized"
|
||||||
|
group_ranks = (
|
||||||
|
all_ranks.transpose(3, 4)
|
||||||
|
.reshape(-1, prefill_context_model_parallel_size)
|
||||||
|
.unbind(0)
|
||||||
|
)
|
||||||
|
group_ranks = [x.tolist() for x in group_ranks]
|
||||||
|
_PCP = init_model_parallel_group(
|
||||||
|
group_ranks, get_world_group().local_rank, backend, group_name="pcp"
|
||||||
|
)
|
||||||
|
|
||||||
# Build the pipeline model-parallel groups.
|
# Build the pipeline model-parallel groups.
|
||||||
global _PP
|
global _PP
|
||||||
assert _PP is None, "pipeline model parallel group is already initialized"
|
assert _PP is None, "pipeline model parallel group is already initialized"
|
||||||
group_ranks = (
|
group_ranks = (
|
||||||
all_ranks.transpose(2, 3).reshape(-1, pipeline_model_parallel_size).unbind(0)
|
all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0)
|
||||||
)
|
)
|
||||||
group_ranks = [x.tolist() for x in group_ranks]
|
group_ranks = [x.tolist() for x in group_ranks]
|
||||||
_PP = init_model_parallel_group(
|
_PP = init_model_parallel_group(
|
||||||
@ -1373,7 +1399,7 @@ def initialize_model_parallel(
|
|||||||
|
|
||||||
global _DP
|
global _DP
|
||||||
assert _DP is None, "data parallel group is already initialized"
|
assert _DP is None, "data parallel group is already initialized"
|
||||||
group_ranks = all_ranks.transpose(1, 3).reshape(-1, data_parallel_size).unbind(0)
|
group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0)
|
||||||
group_ranks = [x.tolist() for x in group_ranks]
|
group_ranks = [x.tolist() for x in group_ranks]
|
||||||
_DP = init_model_parallel_group(
|
_DP = init_model_parallel_group(
|
||||||
group_ranks, get_world_group().local_rank, backend, group_name="dp"
|
group_ranks, get_world_group().local_rank, backend, group_name="dp"
|
||||||
@ -1383,7 +1409,12 @@ def initialize_model_parallel(
|
|||||||
assert _EP is None, "expert parallel group is already initialized"
|
assert _EP is None, "expert parallel group is already initialized"
|
||||||
group_ranks = (
|
group_ranks = (
|
||||||
all_ranks.transpose(1, 2)
|
all_ranks.transpose(1, 2)
|
||||||
.reshape(-1, data_parallel_size * tensor_model_parallel_size)
|
.reshape(
|
||||||
|
-1,
|
||||||
|
data_parallel_size
|
||||||
|
* prefill_context_model_parallel_size
|
||||||
|
* tensor_model_parallel_size,
|
||||||
|
)
|
||||||
.unbind(0)
|
.unbind(0)
|
||||||
)
|
)
|
||||||
group_ranks = [x.tolist() for x in group_ranks]
|
group_ranks = [x.tolist() for x in group_ranks]
|
||||||
@ -1393,11 +1424,13 @@ def initialize_model_parallel(
|
|||||||
|
|
||||||
logger.info_once(
|
logger.info_once(
|
||||||
"rank %s in world size %s is assigned as "
|
"rank %s in world size %s is assigned as "
|
||||||
"DP rank %s, PP rank %s, TP rank %s, EP rank %s",
|
"DP rank %s, PP rank %s, PCP rank %s, "
|
||||||
|
"TP rank %s, EP rank %s",
|
||||||
rank,
|
rank,
|
||||||
world_size,
|
world_size,
|
||||||
_DP.rank_in_group,
|
_DP.rank_in_group,
|
||||||
_PP.rank_in_group,
|
_PP.rank_in_group,
|
||||||
|
_PCP.rank_in_group,
|
||||||
_TP.rank_in_group,
|
_TP.rank_in_group,
|
||||||
_EP.rank_in_group,
|
_EP.rank_in_group,
|
||||||
)
|
)
|
||||||
@ -1406,6 +1439,7 @@ def initialize_model_parallel(
|
|||||||
def ensure_model_parallel_initialized(
|
def ensure_model_parallel_initialized(
|
||||||
tensor_model_parallel_size: int,
|
tensor_model_parallel_size: int,
|
||||||
pipeline_model_parallel_size: int,
|
pipeline_model_parallel_size: int,
|
||||||
|
prefill_context_model_parallel_size: int = 1,
|
||||||
decode_context_model_parallel_size: int | None = 1,
|
decode_context_model_parallel_size: int | None = 1,
|
||||||
backend: str | None = None,
|
backend: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -1418,6 +1452,7 @@ def ensure_model_parallel_initialized(
|
|||||||
initialize_model_parallel(
|
initialize_model_parallel(
|
||||||
tensor_model_parallel_size,
|
tensor_model_parallel_size,
|
||||||
pipeline_model_parallel_size,
|
pipeline_model_parallel_size,
|
||||||
|
prefill_context_model_parallel_size,
|
||||||
decode_context_model_parallel_size,
|
decode_context_model_parallel_size,
|
||||||
backend,
|
backend,
|
||||||
)
|
)
|
||||||
@ -1434,6 +1469,12 @@ def ensure_model_parallel_initialized(
|
|||||||
f"got: {pp_world_size=} vs. "
|
f"got: {pp_world_size=} vs. "
|
||||||
f"wanted: {pipeline_model_parallel_size=}"
|
f"wanted: {pipeline_model_parallel_size=}"
|
||||||
)
|
)
|
||||||
|
pcp_world_size = get_pcp_group().world_size
|
||||||
|
assert pcp_world_size == prefill_context_model_parallel_size, (
|
||||||
|
"prefill context parallel group already initialized, but of unexpected size: "
|
||||||
|
f"{pcp_world_size=} vs. "
|
||||||
|
f"{prefill_context_model_parallel_size=}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def prepare_communication_buffer_for_model(model: torch.nn.Module):
|
def prepare_communication_buffer_for_model(model: torch.nn.Module):
|
||||||
@ -1445,6 +1486,8 @@ def prepare_communication_buffer_for_model(model: torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
if _TP is not None:
|
if _TP is not None:
|
||||||
_TP.prepare_communication_buffer_for_model(model)
|
_TP.prepare_communication_buffer_for_model(model)
|
||||||
|
if _PCP is not None:
|
||||||
|
_PCP.prepare_communication_buffer_for_model(model)
|
||||||
if _PP is not None:
|
if _PP is not None:
|
||||||
_PP.prepare_communication_buffer_for_model(model)
|
_PP.prepare_communication_buffer_for_model(model)
|
||||||
if _DP is not None:
|
if _DP is not None:
|
||||||
@ -1520,16 +1563,21 @@ def destroy_model_parallel():
|
|||||||
_TP.destroy()
|
_TP.destroy()
|
||||||
_TP = None
|
_TP = None
|
||||||
|
|
||||||
global _PP
|
|
||||||
if _PP:
|
|
||||||
_PP.destroy()
|
|
||||||
_PP = None
|
|
||||||
|
|
||||||
global _DCP
|
global _DCP
|
||||||
if _DCP:
|
if _DCP:
|
||||||
_DCP.destroy()
|
_DCP.destroy()
|
||||||
_DCP = None
|
_DCP = None
|
||||||
|
|
||||||
|
global _PCP
|
||||||
|
if _PCP:
|
||||||
|
_PCP.destroy()
|
||||||
|
_PCP = None
|
||||||
|
|
||||||
|
global _PP
|
||||||
|
if _PP:
|
||||||
|
_PP.destroy()
|
||||||
|
_PP = None
|
||||||
|
|
||||||
global _DP
|
global _DP
|
||||||
if _DP:
|
if _DP:
|
||||||
_DP.destroy()
|
_DP.destroy()
|
||||||
|
|||||||
@ -389,8 +389,10 @@ class EngineArgs:
|
|||||||
nnodes: int = ParallelConfig.nnodes
|
nnodes: int = ParallelConfig.nnodes
|
||||||
node_rank: int = ParallelConfig.node_rank
|
node_rank: int = ParallelConfig.node_rank
|
||||||
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
|
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
|
||||||
|
prefill_context_parallel_size: int = ParallelConfig.prefill_context_parallel_size
|
||||||
decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
|
decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
|
||||||
dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
|
dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
|
||||||
|
cp_kv_cache_interleave_size: int = ParallelConfig.cp_kv_cache_interleave_size
|
||||||
data_parallel_size: int = ParallelConfig.data_parallel_size
|
data_parallel_size: int = ParallelConfig.data_parallel_size
|
||||||
data_parallel_rank: int | None = None
|
data_parallel_rank: int | None = None
|
||||||
data_parallel_start_rank: int | None = None
|
data_parallel_start_rank: int | None = None
|
||||||
@ -770,6 +772,15 @@ class EngineArgs:
|
|||||||
"--dcp-kv-cache-interleave-size",
|
"--dcp-kv-cache-interleave-size",
|
||||||
**parallel_kwargs["dcp_kv_cache_interleave_size"],
|
**parallel_kwargs["dcp_kv_cache_interleave_size"],
|
||||||
)
|
)
|
||||||
|
parallel_group.add_argument(
|
||||||
|
"--cp-kv-cache-interleave-size",
|
||||||
|
**parallel_kwargs["cp_kv_cache_interleave_size"],
|
||||||
|
)
|
||||||
|
parallel_group.add_argument(
|
||||||
|
"--prefill-context-parallel-size",
|
||||||
|
"-pcp",
|
||||||
|
**parallel_kwargs["prefill_context_parallel_size"],
|
||||||
|
)
|
||||||
parallel_group.add_argument(
|
parallel_group.add_argument(
|
||||||
"--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
|
"--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
|
||||||
)
|
)
|
||||||
@ -1600,6 +1611,7 @@ class EngineArgs:
|
|||||||
parallel_config = ParallelConfig(
|
parallel_config = ParallelConfig(
|
||||||
pipeline_parallel_size=self.pipeline_parallel_size,
|
pipeline_parallel_size=self.pipeline_parallel_size,
|
||||||
tensor_parallel_size=self.tensor_parallel_size,
|
tensor_parallel_size=self.tensor_parallel_size,
|
||||||
|
prefill_context_parallel_size=self.prefill_context_parallel_size,
|
||||||
data_parallel_size=self.data_parallel_size,
|
data_parallel_size=self.data_parallel_size,
|
||||||
data_parallel_rank=self.data_parallel_rank or 0,
|
data_parallel_rank=self.data_parallel_rank or 0,
|
||||||
data_parallel_external_lb=data_parallel_external_lb,
|
data_parallel_external_lb=data_parallel_external_lb,
|
||||||
@ -1631,6 +1643,7 @@ class EngineArgs:
|
|||||||
worker_extension_cls=self.worker_extension_cls,
|
worker_extension_cls=self.worker_extension_cls,
|
||||||
decode_context_parallel_size=self.decode_context_parallel_size,
|
decode_context_parallel_size=self.decode_context_parallel_size,
|
||||||
dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size,
|
dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size,
|
||||||
|
cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size,
|
||||||
_api_process_count=self._api_process_count,
|
_api_process_count=self._api_process_count,
|
||||||
_api_process_rank=self._api_process_rank,
|
_api_process_rank=self._api_process_rank,
|
||||||
)
|
)
|
||||||
@ -1952,6 +1965,15 @@ class EngineArgs:
|
|||||||
default_prefix_caching,
|
default_prefix_caching,
|
||||||
) = self.get_chunked_prefill_prefix_caching_defaults(model_config)
|
) = self.get_chunked_prefill_prefix_caching_defaults(model_config)
|
||||||
|
|
||||||
|
if self.prefill_context_parallel_size > 1:
|
||||||
|
default_chunked_prefill = False
|
||||||
|
default_prefix_caching = False
|
||||||
|
logger.warning(
|
||||||
|
"--prefill-context-parallel-size > 1 is not compatible with "
|
||||||
|
"chunked prefill and prefix caching now. Chunked prefill "
|
||||||
|
"and prefix caching have been disabled by default."
|
||||||
|
)
|
||||||
|
|
||||||
if self.enable_chunked_prefill is None:
|
if self.enable_chunked_prefill is None:
|
||||||
self.enable_chunked_prefill = default_chunked_prefill
|
self.enable_chunked_prefill = default_chunked_prefill
|
||||||
|
|
||||||
|
|||||||
@ -8,7 +8,11 @@ import torch
|
|||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.config import ParallelConfig
|
from vllm.config import ParallelConfig
|
||||||
from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank
|
from vllm.distributed import (
|
||||||
|
get_dp_group,
|
||||||
|
get_pcp_group,
|
||||||
|
get_tensor_model_parallel_rank,
|
||||||
|
)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
|
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
|
||||||
OCP_MX_DTYPES,
|
OCP_MX_DTYPES,
|
||||||
@ -684,9 +688,11 @@ FUSED_MOE_UNQUANTIZED_CONFIG: FusedMoEQuantConfig = FusedMoEQuantConfig.make()
|
|||||||
@dataclass
|
@dataclass
|
||||||
class FusedMoEParallelConfig:
|
class FusedMoEParallelConfig:
|
||||||
tp_size: int
|
tp_size: int
|
||||||
|
pcp_size: int
|
||||||
dp_size: int
|
dp_size: int
|
||||||
ep_size: int
|
ep_size: int
|
||||||
tp_rank: int
|
tp_rank: int
|
||||||
|
pcp_rank: int
|
||||||
dp_rank: int
|
dp_rank: int
|
||||||
ep_rank: int
|
ep_rank: int
|
||||||
|
|
||||||
@ -713,19 +719,22 @@ class FusedMoEParallelConfig:
|
|||||||
return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"
|
return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def flatten_tp_across_dp(
|
def flatten_tp_across_dp_and_pcp(
|
||||||
tp_size: int, dp_size: int, dp_rank: int
|
tp_size: int, dp_size: int, dp_rank: int, pcp_size: int, pcp_rank: int
|
||||||
) -> tuple[int, int]:
|
) -> tuple[int, int]:
|
||||||
tp_rank = 0 if tp_size == 1 else get_tensor_model_parallel_rank()
|
tp_rank = 0 if tp_size == 1 else get_tensor_model_parallel_rank()
|
||||||
# There are actually dp_size * tp_size devices. Update tp_size
|
# There are actually dp_size * pcp_size * tp_size devices.
|
||||||
# and tp_rank so we shard across all devices.
|
# Update tp_size and tp_rank so we shard across all devices.
|
||||||
flatten_tp_size = dp_size * tp_size
|
flatten_tp_size = dp_size * pcp_size * tp_size
|
||||||
flatten_tp_rank = dp_rank * tp_size + tp_rank
|
flatten_tp_rank = dp_rank * pcp_size * tp_size + pcp_rank * tp_size + tp_rank
|
||||||
return flatten_tp_size, flatten_tp_rank
|
return flatten_tp_size, flatten_tp_rank
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def make(
|
def make(
|
||||||
tp_size_: int, dp_size_: int, vllm_parallel_config: ParallelConfig
|
tp_size_: int,
|
||||||
|
pcp_size_: int,
|
||||||
|
dp_size_: int,
|
||||||
|
vllm_parallel_config: ParallelConfig,
|
||||||
) -> "FusedMoEParallelConfig":
|
) -> "FusedMoEParallelConfig":
|
||||||
"""
|
"""
|
||||||
Determine MoE parallel configuration. Based on the input `tp_size_`,
|
Determine MoE parallel configuration. Based on the input `tp_size_`,
|
||||||
@ -734,19 +743,22 @@ class FusedMoEParallelConfig:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
tp_size_ (int): `tp_size` passed into the FusedMoE constructor.
|
tp_size_ (int): `tp_size` passed into the FusedMoE constructor.
|
||||||
|
pcp_size_ (int): `pcp_size` passed into the FusedMoE constructor.
|
||||||
dp_size_ (int): `dp_size` passed into the FusedMoE constructor.
|
dp_size_ (int): `dp_size` passed into the FusedMoE constructor.
|
||||||
vllm_parallel_config (ParallelConfig): vLLM's parallel config
|
vllm_parallel_config (ParallelConfig): vLLM's parallel config
|
||||||
object which contains the `enable_expert_parallel` flag.
|
object which contains the `enable_expert_parallel` flag.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
When there is no parallelism requested,
|
When there is no parallelism requested,
|
||||||
i.e. `tp_size_` = `dp_size_` = 1, we simply return the sizes
|
i.e. `tp_size_` = `pcp_size_` = `dp_size_` = 1, we simply return the sizes
|
||||||
unaltered and the ranks set to 0.
|
unaltered and the ranks set to 0.
|
||||||
|
|
||||||
Expert Parallelism is considered only when either `dp_size_` or
|
Expert Parallelism is considered only when either `dp_size_`, `pcp_size_` or
|
||||||
`tp_size_` is non trivial.
|
`tp_size_` is non trivial.
|
||||||
|
|
||||||
When TP = 2, DP = 1 and EP = False, the configuration on different
|
Note that PCP serves the same function as DP here.
|
||||||
|
|
||||||
|
When TP = 2, DP(PCP) = 1 and EP = False, the configuration on different
|
||||||
devices:
|
devices:
|
||||||
|
|
||||||
- device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} //
|
- device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} //
|
||||||
@ -754,7 +766,7 @@ class FusedMoEParallelConfig:
|
|||||||
- device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0}
|
- device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0}
|
||||||
- Comment : Tensors are sharded across 2 devices.
|
- Comment : Tensors are sharded across 2 devices.
|
||||||
|
|
||||||
When TP = 1, DP = 2 and EP = False, the configuration on different
|
When TP = 1, DP(PCP) = 2 and EP = False, the configuration on different
|
||||||
devices:
|
devices:
|
||||||
|
|
||||||
- device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0}
|
- device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0}
|
||||||
@ -762,7 +774,7 @@ class FusedMoEParallelConfig:
|
|||||||
- Comment: There are 2 engine instances and the tensors are sharded
|
- Comment: There are 2 engine instances and the tensors are sharded
|
||||||
across 2 decvices.
|
across 2 decvices.
|
||||||
|
|
||||||
When TP = 2, DP = 2 and EP = False, the configuration on different
|
When TP = 2, DP(PCP) = 2 and EP = False, the configuration on different
|
||||||
devices:
|
devices:
|
||||||
|
|
||||||
- device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0}
|
- device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0}
|
||||||
@ -772,14 +784,14 @@ class FusedMoEParallelConfig:
|
|||||||
- Comment: There are 2 engine instances and the tensors are sharded
|
- Comment: There are 2 engine instances and the tensors are sharded
|
||||||
across 4 devices.
|
across 4 devices.
|
||||||
|
|
||||||
When, TP = 2, DP = 1 and EP = True, the configuration on different
|
When, TP = 2, DP(PCP) = 1 and EP = True, the configuration on different
|
||||||
devices:
|
devices:
|
||||||
|
|
||||||
- device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0}
|
- device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0}
|
||||||
- device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1}
|
- device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1}
|
||||||
- Comment: The experts are split between the 2 devices.
|
- Comment: The experts are split between the 2 devices.
|
||||||
|
|
||||||
When, TP = 1, DP = 2 and EP = True, the configuration on different
|
When, TP = 1, DP(PCP) = 2 and EP = True, the configuration on different
|
||||||
devices:
|
devices:
|
||||||
|
|
||||||
- device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0}
|
- device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0}
|
||||||
@ -787,7 +799,7 @@ class FusedMoEParallelConfig:
|
|||||||
- Comment: There are 2 engine instances and the experts are split
|
- Comment: There are 2 engine instances and the experts are split
|
||||||
between the 2 devices.
|
between the 2 devices.
|
||||||
|
|
||||||
When TP = 2, DP = 2 and EP = True, the configuration on different
|
When TP = 2, DP(PCP) = 2 and EP = True, the configuration on different
|
||||||
devices:
|
devices:
|
||||||
|
|
||||||
- device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0}
|
- device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0}
|
||||||
@ -798,18 +810,25 @@ class FusedMoEParallelConfig:
|
|||||||
between the 4 devices.
|
between the 4 devices.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
use_ep = dp_size_ * tp_size_ > 1 and vllm_parallel_config.enable_expert_parallel
|
use_ep = (
|
||||||
|
dp_size_ * pcp_size_ * tp_size_ > 1
|
||||||
|
and vllm_parallel_config.enable_expert_parallel
|
||||||
|
)
|
||||||
|
|
||||||
dp_size = dp_size_
|
dp_size = dp_size_
|
||||||
dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
|
dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
|
||||||
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
|
pcp_size = pcp_size_
|
||||||
tp_size_, dp_size_, dp_rank
|
pcp_rank = get_pcp_group().rank_in_group if pcp_size > 1 else 0
|
||||||
|
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
|
||||||
|
tp_size_, dp_size_, dp_rank, pcp_size_, pcp_rank
|
||||||
)
|
)
|
||||||
|
|
||||||
if not use_ep:
|
if not use_ep:
|
||||||
return FusedMoEParallelConfig(
|
return FusedMoEParallelConfig(
|
||||||
tp_size=tp_size,
|
tp_size=tp_size,
|
||||||
tp_rank=tp_rank,
|
tp_rank=tp_rank,
|
||||||
|
pcp_size=pcp_size,
|
||||||
|
pcp_rank=pcp_rank,
|
||||||
dp_size=dp_size,
|
dp_size=dp_size,
|
||||||
dp_rank=dp_rank,
|
dp_rank=dp_rank,
|
||||||
ep_size=1,
|
ep_size=1,
|
||||||
@ -826,6 +845,8 @@ class FusedMoEParallelConfig:
|
|||||||
return FusedMoEParallelConfig(
|
return FusedMoEParallelConfig(
|
||||||
tp_size=1,
|
tp_size=1,
|
||||||
tp_rank=0,
|
tp_rank=0,
|
||||||
|
pcp_size=pcp_size,
|
||||||
|
pcp_rank=pcp_rank,
|
||||||
dp_size=dp_size,
|
dp_size=dp_size,
|
||||||
dp_rank=dp_rank,
|
dp_rank=dp_rank,
|
||||||
ep_size=ep_size,
|
ep_size=ep_size,
|
||||||
|
|||||||
@ -18,6 +18,7 @@ from vllm.config.parallel import ExpertPlacementStrategy
|
|||||||
from vllm.distributed import (
|
from vllm.distributed import (
|
||||||
get_dp_group,
|
get_dp_group,
|
||||||
get_ep_group,
|
get_ep_group,
|
||||||
|
get_pcp_group,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
@ -343,6 +344,7 @@ class FusedMoE(CustomOp):
|
|||||||
tp_size: int | None = None,
|
tp_size: int | None = None,
|
||||||
ep_size: int | None = None,
|
ep_size: int | None = None,
|
||||||
dp_size: int | None = None,
|
dp_size: int | None = None,
|
||||||
|
pcp_size: int | None = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
custom_routing_function: Callable | None = None,
|
custom_routing_function: Callable | None = None,
|
||||||
scoring_func: str = "softmax",
|
scoring_func: str = "softmax",
|
||||||
@ -398,12 +400,14 @@ class FusedMoE(CustomOp):
|
|||||||
tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
|
tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
|
||||||
)
|
)
|
||||||
dp_size_ = dp_size if dp_size is not None else get_dp_group().world_size
|
dp_size_ = dp_size if dp_size is not None else get_dp_group().world_size
|
||||||
|
pcp_size_ = pcp_size if pcp_size is not None else get_pcp_group().world_size
|
||||||
|
|
||||||
self.is_sequence_parallel = is_sequence_parallel
|
self.is_sequence_parallel = is_sequence_parallel
|
||||||
self.sp_size = tp_size_ if is_sequence_parallel else 1
|
self.sp_size = tp_size_ if is_sequence_parallel else 1
|
||||||
|
|
||||||
self.moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
|
self.moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
|
||||||
tp_size_=tp_size_,
|
tp_size_=tp_size_,
|
||||||
|
pcp_size_=pcp_size_,
|
||||||
dp_size_=dp_size_,
|
dp_size_=dp_size_,
|
||||||
vllm_parallel_config=vllm_config.parallel_config,
|
vllm_parallel_config=vllm_config.parallel_config,
|
||||||
)
|
)
|
||||||
@ -679,6 +683,10 @@ class FusedMoE(CustomOp):
|
|||||||
def dp_size(self):
|
def dp_size(self):
|
||||||
return self.moe_parallel_config.dp_size
|
return self.moe_parallel_config.dp_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pcp_size(self):
|
||||||
|
return self.moe_parallel_config.pcp_size
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ep_size(self):
|
def ep_size(self):
|
||||||
return self.moe_parallel_config.ep_size
|
return self.moe_parallel_config.ep_size
|
||||||
@ -691,6 +699,10 @@ class FusedMoE(CustomOp):
|
|||||||
def dp_rank(self):
|
def dp_rank(self):
|
||||||
return self.moe_parallel_config.dp_rank
|
return self.moe_parallel_config.dp_rank
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pcp_rank(self):
|
||||||
|
return self.moe_parallel_config.pcp_rank
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ep_rank(self):
|
def ep_rank(self):
|
||||||
return self.moe_parallel_config.ep_rank
|
return self.moe_parallel_config.ep_rank
|
||||||
@ -1871,6 +1883,19 @@ class FusedMoE(CustomOp):
|
|||||||
assert self.shared_experts is not None
|
assert self.shared_experts is not None
|
||||||
shared_output = self.shared_experts(hidden_states)
|
shared_output = self.shared_experts(hidden_states)
|
||||||
|
|
||||||
|
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
|
||||||
|
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
|
||||||
|
# we should modify All2AllManager abstract to better support PCP.
|
||||||
|
if self.pcp_size > 1:
|
||||||
|
hidden_states = get_pcp_group().all_gather(
|
||||||
|
hidden_states,
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
router_logits = get_pcp_group().all_gather(
|
||||||
|
router_logits,
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
|
||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
final_hidden_states = self.quant_method.apply(
|
final_hidden_states = self.quant_method.apply(
|
||||||
layer=self,
|
layer=self,
|
||||||
@ -1925,6 +1950,13 @@ class FusedMoE(CustomOp):
|
|||||||
def combine_output(states: torch.Tensor) -> torch.Tensor:
|
def combine_output(states: torch.Tensor) -> torch.Tensor:
|
||||||
if do_naive_dispatch_combine:
|
if do_naive_dispatch_combine:
|
||||||
states = get_ep_group().combine(states, self.is_sequence_parallel)
|
states = get_ep_group().combine(states, self.is_sequence_parallel)
|
||||||
|
|
||||||
|
if self.pcp_size > 1:
|
||||||
|
states = get_pcp_group().reduce_scatter(
|
||||||
|
states,
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
|
||||||
return states
|
return states
|
||||||
|
|
||||||
if self.shared_experts is not None:
|
if self.shared_experts is not None:
|
||||||
|
|||||||
@ -13,6 +13,7 @@ from vllm.config import CacheConfig, VllmConfig
|
|||||||
from vllm.distributed import (
|
from vllm.distributed import (
|
||||||
get_dp_group,
|
get_dp_group,
|
||||||
get_ep_group,
|
get_ep_group,
|
||||||
|
get_pcp_group,
|
||||||
get_pp_group,
|
get_pp_group,
|
||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
@ -322,10 +323,12 @@ class GptOssModel(nn.Module):
|
|||||||
|
|
||||||
# In MoE, we need to flatten the tensor parallel size across the data
|
# In MoE, we need to flatten the tensor parallel size across the data
|
||||||
# parallel size when EP is disabled.
|
# parallel size when EP is disabled.
|
||||||
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
|
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
|
||||||
tp_size=get_tensor_model_parallel_world_size(),
|
tp_size=get_tensor_model_parallel_world_size(),
|
||||||
dp_size=get_dp_group().world_size,
|
dp_size=get_dp_group().world_size,
|
||||||
dp_rank=get_dp_group().rank_in_group,
|
dp_rank=get_dp_group().rank_in_group,
|
||||||
|
pcp_size=get_pcp_group().world_size,
|
||||||
|
pcp_rank=get_pcp_group().rank_in_group,
|
||||||
)
|
)
|
||||||
|
|
||||||
intermediate_size = self.config.intermediate_size
|
intermediate_size = self.config.intermediate_size
|
||||||
@ -507,10 +510,12 @@ class GptOssModel(nn.Module):
|
|||||||
|
|
||||||
# In MoE, we need to flatten the tensor parallel size across the data
|
# In MoE, we need to flatten the tensor parallel size across the data
|
||||||
# parallel size when EP is disabled.
|
# parallel size when EP is disabled.
|
||||||
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
|
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
|
||||||
tp_size=get_tensor_model_parallel_world_size(),
|
tp_size=get_tensor_model_parallel_world_size(),
|
||||||
dp_size=get_dp_group().world_size,
|
dp_size=get_dp_group().world_size,
|
||||||
dp_rank=get_dp_group().rank_in_group,
|
dp_rank=get_dp_group().rank_in_group,
|
||||||
|
pcp_size=get_pcp_group().world_size,
|
||||||
|
pcp_rank=get_pcp_group().rank_in_group,
|
||||||
)
|
)
|
||||||
|
|
||||||
intermediate_size = self.config.intermediate_size
|
intermediate_size = self.config.intermediate_size
|
||||||
|
|||||||
@ -265,8 +265,8 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
|||||||
self.dcp_world_size = 1
|
self.dcp_world_size = 1
|
||||||
self.dcp_rank = 0
|
self.dcp_rank = 0
|
||||||
|
|
||||||
self.dcp_kv_cache_interleave_size = (
|
self.cp_kv_cache_interleave_size = (
|
||||||
self.parallel_config.dcp_kv_cache_interleave_size
|
self.parallel_config.cp_kv_cache_interleave_size
|
||||||
)
|
)
|
||||||
|
|
||||||
self.use_full_cuda_graph = (
|
self.use_full_cuda_graph = (
|
||||||
@ -388,7 +388,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
|||||||
dcp_context_kv_lens_cpu,
|
dcp_context_kv_lens_cpu,
|
||||||
self.dcp_world_size,
|
self.dcp_world_size,
|
||||||
self.dcp_rank,
|
self.dcp_rank,
|
||||||
self.dcp_kv_cache_interleave_size,
|
self.cp_kv_cache_interleave_size,
|
||||||
)
|
)
|
||||||
dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
|
dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
|
||||||
max_dcp_context_kv_len = dcp_context_kv_lens.max().item()
|
max_dcp_context_kv_len = dcp_context_kv_lens.max().item()
|
||||||
|
|||||||
@ -536,7 +536,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
# DCP might not be initialized in testing
|
# DCP might not be initialized in testing
|
||||||
self.dcp_world_size = 1
|
self.dcp_world_size = 1
|
||||||
self.dcp_rank = 0
|
self.dcp_rank = 0
|
||||||
self.dcp_local_block_size = parallel_config.dcp_kv_cache_interleave_size
|
self.dcp_local_block_size = parallel_config.cp_kv_cache_interleave_size
|
||||||
self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size
|
self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size
|
||||||
|
|
||||||
# Don't try to access the runner on AMD
|
# Don't try to access the runner on AMD
|
||||||
@ -1289,8 +1289,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
get_current_vllm_config()
|
get_current_vllm_config()
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.dcp_kv_cache_interleave_size: int = (
|
self.cp_kv_cache_interleave_size: int = (
|
||||||
get_current_vllm_config().parallel_config.dcp_kv_cache_interleave_size
|
get_current_vllm_config().parallel_config.cp_kv_cache_interleave_size
|
||||||
)
|
)
|
||||||
|
|
||||||
def _flash_attn_varlen_diff_headdims(
|
def _flash_attn_varlen_diff_headdims(
|
||||||
|
|||||||
@ -1080,9 +1080,9 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
|
|||||||
|
|
||||||
def get_dcp_local_seq_lens(
|
def get_dcp_local_seq_lens(
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
dcp_world_size: int = 1,
|
dcp_size: int = 1,
|
||||||
dcp_rank: int | None = None,
|
dcp_rank: int | None = None,
|
||||||
dcp_kv_cache_interleave_size: int = 1,
|
cp_kv_cache_interleave_size: int = 1,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""While using dcp, kv_cache size stored on each rank may be different,
|
"""While using dcp, kv_cache size stored on each rank may be different,
|
||||||
use this function to calculate split decode seq_lens of each dcp rank.
|
use this function to calculate split decode seq_lens of each dcp rank.
|
||||||
@ -1091,7 +1091,7 @@ def get_dcp_local_seq_lens(
|
|||||||
num_requests = seq_lens.size(0)
|
num_requests = seq_lens.size(0)
|
||||||
if dcp_rank is None:
|
if dcp_rank is None:
|
||||||
rank_offsets = (
|
rank_offsets = (
|
||||||
torch.arange(dcp_world_size, dtype=torch.int32)
|
torch.arange(dcp_size, dtype=torch.int32)
|
||||||
.unsqueeze(0)
|
.unsqueeze(0)
|
||||||
.repeat(num_requests, 1)
|
.repeat(num_requests, 1)
|
||||||
)
|
)
|
||||||
@ -1102,15 +1102,15 @@ def get_dcp_local_seq_lens(
|
|||||||
)
|
)
|
||||||
base = (
|
base = (
|
||||||
seq_lens_tiled
|
seq_lens_tiled
|
||||||
// dcp_kv_cache_interleave_size
|
// cp_kv_cache_interleave_size
|
||||||
// dcp_world_size
|
// dcp_size
|
||||||
* dcp_kv_cache_interleave_size
|
* cp_kv_cache_interleave_size
|
||||||
)
|
)
|
||||||
remainder = seq_lens_tiled - base * dcp_world_size
|
remainder = seq_lens_tiled - base * dcp_size
|
||||||
remainder = torch.clip(
|
remainder = torch.clip(
|
||||||
remainder - rank_offsets * dcp_kv_cache_interleave_size,
|
remainder - rank_offsets * cp_kv_cache_interleave_size,
|
||||||
0,
|
0,
|
||||||
dcp_kv_cache_interleave_size,
|
cp_kv_cache_interleave_size,
|
||||||
)
|
)
|
||||||
dcp_local_seq_lens = base + remainder
|
dcp_local_seq_lens = base + remainder
|
||||||
return dcp_local_seq_lens.squeeze(1)
|
return dcp_local_seq_lens.squeeze(1)
|
||||||
|
|||||||
@ -27,6 +27,7 @@ class KVCacheCoordinator(ABC):
|
|||||||
enable_caching: bool,
|
enable_caching: bool,
|
||||||
enable_kv_cache_events: bool,
|
enable_kv_cache_events: bool,
|
||||||
dcp_world_size: int,
|
dcp_world_size: int,
|
||||||
|
pcp_world_size: int,
|
||||||
):
|
):
|
||||||
self.kv_cache_config = kv_cache_config
|
self.kv_cache_config = kv_cache_config
|
||||||
self.max_model_len = max_model_len
|
self.max_model_len = max_model_len
|
||||||
@ -44,6 +45,7 @@ class KVCacheCoordinator(ABC):
|
|||||||
block_pool=self.block_pool,
|
block_pool=self.block_pool,
|
||||||
kv_cache_group_id=i,
|
kv_cache_group_id=i,
|
||||||
dcp_world_size=dcp_world_size,
|
dcp_world_size=dcp_world_size,
|
||||||
|
pcp_world_size=pcp_world_size,
|
||||||
)
|
)
|
||||||
for i, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups)
|
for i, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups)
|
||||||
)
|
)
|
||||||
@ -210,6 +212,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
|
|||||||
use_eagle: bool,
|
use_eagle: bool,
|
||||||
enable_kv_cache_events: bool,
|
enable_kv_cache_events: bool,
|
||||||
dcp_world_size: int,
|
dcp_world_size: int,
|
||||||
|
pcp_world_size: int,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
kv_cache_config,
|
kv_cache_config,
|
||||||
@ -218,6 +221,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
|
|||||||
False,
|
False,
|
||||||
enable_kv_cache_events,
|
enable_kv_cache_events,
|
||||||
dcp_world_size=dcp_world_size,
|
dcp_world_size=dcp_world_size,
|
||||||
|
pcp_world_size=pcp_world_size,
|
||||||
)
|
)
|
||||||
self.num_single_type_manager = len(self.single_type_managers)
|
self.num_single_type_manager = len(self.single_type_managers)
|
||||||
|
|
||||||
@ -250,6 +254,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
|
|||||||
enable_caching: bool,
|
enable_caching: bool,
|
||||||
enable_kv_cache_events: bool,
|
enable_kv_cache_events: bool,
|
||||||
dcp_world_size: int,
|
dcp_world_size: int,
|
||||||
|
pcp_world_size: int,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
kv_cache_config,
|
kv_cache_config,
|
||||||
@ -258,12 +263,16 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
|
|||||||
enable_caching,
|
enable_caching,
|
||||||
enable_kv_cache_events,
|
enable_kv_cache_events,
|
||||||
dcp_world_size=dcp_world_size,
|
dcp_world_size=dcp_world_size,
|
||||||
|
pcp_world_size=pcp_world_size,
|
||||||
)
|
)
|
||||||
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec
|
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec
|
||||||
self.block_size = self.kv_cache_spec.block_size
|
self.block_size = self.kv_cache_spec.block_size
|
||||||
self.dcp_world_size = dcp_world_size
|
self.dcp_world_size = dcp_world_size
|
||||||
|
self.pcp_world_size = pcp_world_size
|
||||||
if dcp_world_size > 1:
|
if dcp_world_size > 1:
|
||||||
self.block_size *= dcp_world_size
|
self.block_size *= dcp_world_size
|
||||||
|
if pcp_world_size > 1:
|
||||||
|
self.block_size *= pcp_world_size
|
||||||
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
|
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
|
||||||
"UnitaryKVCacheCoordinator assumes only one kv cache group"
|
"UnitaryKVCacheCoordinator assumes only one kv cache group"
|
||||||
)
|
)
|
||||||
@ -281,6 +290,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
|
|||||||
kv_cache_spec=self.kv_cache_spec,
|
kv_cache_spec=self.kv_cache_spec,
|
||||||
use_eagle=self.use_eagle,
|
use_eagle=self.use_eagle,
|
||||||
dcp_world_size=self.dcp_world_size,
|
dcp_world_size=self.dcp_world_size,
|
||||||
|
pcp_world_size=self.pcp_world_size,
|
||||||
)
|
)
|
||||||
return hit_blocks, len(hit_blocks[0]) * self.block_size
|
return hit_blocks, len(hit_blocks[0]) * self.block_size
|
||||||
|
|
||||||
@ -302,6 +312,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
|||||||
enable_caching: bool,
|
enable_caching: bool,
|
||||||
enable_kv_cache_events: bool,
|
enable_kv_cache_events: bool,
|
||||||
dcp_world_size: int,
|
dcp_world_size: int,
|
||||||
|
pcp_world_size: int,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
kv_cache_config,
|
kv_cache_config,
|
||||||
@ -310,8 +321,10 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
|||||||
enable_caching,
|
enable_caching,
|
||||||
enable_kv_cache_events,
|
enable_kv_cache_events,
|
||||||
dcp_world_size=dcp_world_size,
|
dcp_world_size=dcp_world_size,
|
||||||
|
pcp_world_size=pcp_world_size,
|
||||||
)
|
)
|
||||||
assert dcp_world_size == 1, "DCP not support hybrid attn now."
|
assert dcp_world_size == 1, "DCP not support hybrid attn now."
|
||||||
|
assert pcp_world_size == 1, "PCP not support hybrid attn now."
|
||||||
self.verify_and_split_kv_cache_groups()
|
self.verify_and_split_kv_cache_groups()
|
||||||
|
|
||||||
def verify_and_split_kv_cache_groups(self) -> None:
|
def verify_and_split_kv_cache_groups(self) -> None:
|
||||||
@ -452,6 +465,7 @@ def get_kv_cache_coordinator(
|
|||||||
enable_caching: bool,
|
enable_caching: bool,
|
||||||
enable_kv_cache_events: bool,
|
enable_kv_cache_events: bool,
|
||||||
dcp_world_size: int,
|
dcp_world_size: int,
|
||||||
|
pcp_world_size: int,
|
||||||
) -> KVCacheCoordinator:
|
) -> KVCacheCoordinator:
|
||||||
if not enable_caching:
|
if not enable_caching:
|
||||||
return KVCacheCoordinatorNoPrefixCache(
|
return KVCacheCoordinatorNoPrefixCache(
|
||||||
@ -460,6 +474,7 @@ def get_kv_cache_coordinator(
|
|||||||
use_eagle,
|
use_eagle,
|
||||||
enable_kv_cache_events,
|
enable_kv_cache_events,
|
||||||
dcp_world_size=dcp_world_size,
|
dcp_world_size=dcp_world_size,
|
||||||
|
pcp_world_size=pcp_world_size,
|
||||||
)
|
)
|
||||||
if len(kv_cache_config.kv_cache_groups) == 1:
|
if len(kv_cache_config.kv_cache_groups) == 1:
|
||||||
return UnitaryKVCacheCoordinator(
|
return UnitaryKVCacheCoordinator(
|
||||||
@ -469,6 +484,7 @@ def get_kv_cache_coordinator(
|
|||||||
enable_caching,
|
enable_caching,
|
||||||
enable_kv_cache_events,
|
enable_kv_cache_events,
|
||||||
dcp_world_size=dcp_world_size,
|
dcp_world_size=dcp_world_size,
|
||||||
|
pcp_world_size=pcp_world_size,
|
||||||
)
|
)
|
||||||
return HybridKVCacheCoordinator(
|
return HybridKVCacheCoordinator(
|
||||||
kv_cache_config,
|
kv_cache_config,
|
||||||
@ -477,4 +493,5 @@ def get_kv_cache_coordinator(
|
|||||||
enable_caching,
|
enable_caching,
|
||||||
enable_kv_cache_events,
|
enable_kv_cache_events,
|
||||||
dcp_world_size=dcp_world_size,
|
dcp_world_size=dcp_world_size,
|
||||||
|
pcp_world_size=pcp_world_size,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -100,6 +100,7 @@ class KVCacheManager:
|
|||||||
log_stats: bool = False,
|
log_stats: bool = False,
|
||||||
enable_kv_cache_events: bool = False,
|
enable_kv_cache_events: bool = False,
|
||||||
dcp_world_size: int = 1,
|
dcp_world_size: int = 1,
|
||||||
|
pcp_world_size: int = 1,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.max_model_len = max_model_len
|
self.max_model_len = max_model_len
|
||||||
|
|
||||||
@ -124,12 +125,9 @@ class KVCacheManager:
|
|||||||
0
|
0
|
||||||
].kv_cache_spec.block_size
|
].kv_cache_spec.block_size
|
||||||
|
|
||||||
if dcp_world_size > 1:
|
if dcp_world_size * pcp_world_size > 1:
|
||||||
assert len(kv_cache_config.kv_cache_groups) == 1
|
assert len(kv_cache_config.kv_cache_groups) == 1
|
||||||
# Note(hc): need revisit. When both DCP and any future
|
self.block_size *= dcp_world_size * pcp_world_size
|
||||||
# PCP are enabled, the block_size may need to be scaled
|
|
||||||
# by a factor of dcp_size × pcp_size?
|
|
||||||
self.block_size *= dcp_world_size
|
|
||||||
|
|
||||||
self.coordinator = get_kv_cache_coordinator(
|
self.coordinator = get_kv_cache_coordinator(
|
||||||
kv_cache_config=kv_cache_config,
|
kv_cache_config=kv_cache_config,
|
||||||
@ -138,6 +136,7 @@ class KVCacheManager:
|
|||||||
enable_caching=self.enable_caching,
|
enable_caching=self.enable_caching,
|
||||||
enable_kv_cache_events=enable_kv_cache_events,
|
enable_kv_cache_events=enable_kv_cache_events,
|
||||||
dcp_world_size=dcp_world_size,
|
dcp_world_size=dcp_world_size,
|
||||||
|
pcp_world_size=pcp_world_size,
|
||||||
)
|
)
|
||||||
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
|
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
|
||||||
self.block_pool = self.coordinator.block_pool
|
self.block_pool = self.coordinator.block_pool
|
||||||
|
|||||||
@ -1219,11 +1219,16 @@ def _report_kv_cache_config(
|
|||||||
// len(kv_cache_config.kv_cache_groups)
|
// len(kv_cache_config.kv_cache_groups)
|
||||||
* min_block_size
|
* min_block_size
|
||||||
)
|
)
|
||||||
if vllm_config.parallel_config.decode_context_parallel_size > 1:
|
dcp_size = vllm_config.parallel_config.decode_context_parallel_size
|
||||||
num_tokens *= vllm_config.parallel_config.decode_context_parallel_size
|
pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
|
||||||
|
if pcp_size * dcp_size > 1:
|
||||||
|
num_tokens *= pcp_size * dcp_size
|
||||||
logger.info(
|
logger.info(
|
||||||
"Multiplying the GPU KV cache size by the dcp_world_size %d.",
|
"Multiplying the GPU KV cache size by the cp_world_size %d "
|
||||||
vllm_config.parallel_config.decode_context_parallel_size,
|
"(pcp_world_size %d * dcp_world_size %d).",
|
||||||
|
pcp_size * dcp_size,
|
||||||
|
pcp_size,
|
||||||
|
dcp_size,
|
||||||
)
|
)
|
||||||
num_tokens_str = f"{num_tokens:,}"
|
num_tokens_str = f"{num_tokens:,}"
|
||||||
logger.info_once("GPU KV cache size: %s tokens", num_tokens_str, scope="local")
|
logger.info_once("GPU KV cache size: %s tokens", num_tokens_str, scope="local")
|
||||||
|
|||||||
@ -121,6 +121,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
|
|
||||||
self.block_size = block_size
|
self.block_size = block_size
|
||||||
self.dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
|
self.dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
|
||||||
|
self.pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size
|
||||||
|
|
||||||
# req_id -> Request
|
# req_id -> Request
|
||||||
self.requests: dict[str, Request] = {}
|
self.requests: dict[str, Request] = {}
|
||||||
@ -183,6 +184,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
log_stats=self.log_stats,
|
log_stats=self.log_stats,
|
||||||
enable_kv_cache_events=self.enable_kv_cache_events,
|
enable_kv_cache_events=self.enable_kv_cache_events,
|
||||||
dcp_world_size=self.dcp_world_size,
|
dcp_world_size=self.dcp_world_size,
|
||||||
|
pcp_world_size=self.pcp_world_size,
|
||||||
)
|
)
|
||||||
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
|
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
|
||||||
|
|
||||||
|
|||||||
@ -32,6 +32,7 @@ class SingleTypeKVCacheManager(ABC):
|
|||||||
block_pool: BlockPool,
|
block_pool: BlockPool,
|
||||||
kv_cache_group_id: int,
|
kv_cache_group_id: int,
|
||||||
dcp_world_size: int = 1,
|
dcp_world_size: int = 1,
|
||||||
|
pcp_world_size: int = 1,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Initializes the SingleTypeKVCacheManager.
|
Initializes the SingleTypeKVCacheManager.
|
||||||
@ -42,8 +43,9 @@ class SingleTypeKVCacheManager(ABC):
|
|||||||
"""
|
"""
|
||||||
self.block_size = kv_cache_spec.block_size
|
self.block_size = kv_cache_spec.block_size
|
||||||
self.dcp_world_size = dcp_world_size
|
self.dcp_world_size = dcp_world_size
|
||||||
if self.dcp_world_size > 1:
|
self.pcp_world_size = pcp_world_size
|
||||||
self.block_size *= dcp_world_size
|
if dcp_world_size * pcp_world_size > 1:
|
||||||
|
self.block_size *= dcp_world_size * pcp_world_size
|
||||||
self.kv_cache_spec = kv_cache_spec
|
self.kv_cache_spec = kv_cache_spec
|
||||||
self.block_pool = block_pool
|
self.block_pool = block_pool
|
||||||
|
|
||||||
@ -212,6 +214,7 @@ class SingleTypeKVCacheManager(ABC):
|
|||||||
kv_cache_spec: KVCacheSpec,
|
kv_cache_spec: KVCacheSpec,
|
||||||
use_eagle: bool,
|
use_eagle: bool,
|
||||||
dcp_world_size: int = 1,
|
dcp_world_size: int = 1,
|
||||||
|
pcp_world_size: int = 1,
|
||||||
) -> tuple[list[KVCacheBlock], ...]:
|
) -> tuple[list[KVCacheBlock], ...]:
|
||||||
"""
|
"""
|
||||||
Get the longest cache hit prefix of the blocks that is not longer than
|
Get the longest cache hit prefix of the blocks that is not longer than
|
||||||
@ -303,6 +306,7 @@ class FullAttentionManager(SingleTypeKVCacheManager):
|
|||||||
kv_cache_spec: KVCacheSpec,
|
kv_cache_spec: KVCacheSpec,
|
||||||
use_eagle: bool,
|
use_eagle: bool,
|
||||||
dcp_world_size: int = 1,
|
dcp_world_size: int = 1,
|
||||||
|
pcp_world_size: int = 1,
|
||||||
) -> tuple[list[KVCacheBlock], ...]:
|
) -> tuple[list[KVCacheBlock], ...]:
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec)
|
kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec)
|
||||||
@ -314,8 +318,8 @@ class FullAttentionManager(SingleTypeKVCacheManager):
|
|||||||
[] for _ in range(len(kv_cache_group_ids))
|
[] for _ in range(len(kv_cache_group_ids))
|
||||||
)
|
)
|
||||||
block_size = kv_cache_spec.block_size
|
block_size = kv_cache_spec.block_size
|
||||||
if dcp_world_size > 1:
|
if dcp_world_size * pcp_world_size > 1:
|
||||||
block_size *= dcp_world_size
|
block_size *= dcp_world_size * pcp_world_size
|
||||||
max_num_blocks = max_length // block_size
|
max_num_blocks = max_length // block_size
|
||||||
for block_hash in itertools.islice(block_hashes, max_num_blocks):
|
for block_hash in itertools.islice(block_hashes, max_num_blocks):
|
||||||
# block_hashes is a chain of block hashes. If a block hash is not
|
# block_hashes is a chain of block hashes. If a block hash is not
|
||||||
@ -362,11 +366,13 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
|
|||||||
kv_cache_spec: KVCacheSpec,
|
kv_cache_spec: KVCacheSpec,
|
||||||
use_eagle: bool,
|
use_eagle: bool,
|
||||||
dcp_world_size: int = 1,
|
dcp_world_size: int = 1,
|
||||||
|
pcp_world_size: int = 1,
|
||||||
) -> tuple[list[KVCacheBlock], ...]:
|
) -> tuple[list[KVCacheBlock], ...]:
|
||||||
assert isinstance(kv_cache_spec, SlidingWindowSpec), (
|
assert isinstance(kv_cache_spec, SlidingWindowSpec), (
|
||||||
"SlidingWindowManager can only be used for sliding window groups"
|
"SlidingWindowManager can only be used for sliding window groups"
|
||||||
)
|
)
|
||||||
assert dcp_world_size == 1, "DCP not support sliding window attn now."
|
assert dcp_world_size == 1, "DCP not support sliding window attn now."
|
||||||
|
assert pcp_world_size == 1, "PCP not support sliding window attn now."
|
||||||
|
|
||||||
# The number of contiguous blocks needed for prefix cache hit.
|
# The number of contiguous blocks needed for prefix cache hit.
|
||||||
# -1 since the input token itself is also included in the window
|
# -1 since the input token itself is also included in the window
|
||||||
@ -476,6 +482,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
|
|||||||
kv_cache_spec: KVCacheSpec,
|
kv_cache_spec: KVCacheSpec,
|
||||||
use_eagle: bool,
|
use_eagle: bool,
|
||||||
dcp_world_size: int = 1,
|
dcp_world_size: int = 1,
|
||||||
|
pcp_world_size: int = 1,
|
||||||
) -> tuple[list[KVCacheBlock], ...]:
|
) -> tuple[list[KVCacheBlock], ...]:
|
||||||
"""
|
"""
|
||||||
For chunked local attention, we need to find the longest cache hit
|
For chunked local attention, we need to find the longest cache hit
|
||||||
@ -516,6 +523,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
|
|||||||
"Hybrid KV cache is not supported for " + "eagle + chunked local attention."
|
"Hybrid KV cache is not supported for " + "eagle + chunked local attention."
|
||||||
)
|
)
|
||||||
assert dcp_world_size == 1, "DCP not support chunked local attn now."
|
assert dcp_world_size == 1, "DCP not support chunked local attn now."
|
||||||
|
assert pcp_world_size == 1, "PCP not support chunked local attn now."
|
||||||
max_num_blocks = max_length // kv_cache_spec.block_size
|
max_num_blocks = max_length // kv_cache_spec.block_size
|
||||||
if max_length > 0:
|
if max_length > 0:
|
||||||
local_attention_start_idx = (
|
local_attention_start_idx = (
|
||||||
@ -611,11 +619,13 @@ class MambaManager(SingleTypeKVCacheManager):
|
|||||||
kv_cache_spec: KVCacheSpec,
|
kv_cache_spec: KVCacheSpec,
|
||||||
use_eagle: bool,
|
use_eagle: bool,
|
||||||
dcp_world_size: int = 1,
|
dcp_world_size: int = 1,
|
||||||
|
pcp_world_size: int = 1,
|
||||||
) -> tuple[list[KVCacheBlock], ...]:
|
) -> tuple[list[KVCacheBlock], ...]:
|
||||||
assert isinstance(kv_cache_spec, MambaSpec), (
|
assert isinstance(kv_cache_spec, MambaSpec), (
|
||||||
"MambaManager can only be used for mamba groups"
|
"MambaManager can only be used for mamba groups"
|
||||||
)
|
)
|
||||||
assert dcp_world_size == 1, "DCP not support mamba now."
|
assert dcp_world_size == 1, "DCP not support mamba now."
|
||||||
|
assert pcp_world_size == 1, "PCP not support mamba now."
|
||||||
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
|
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
|
||||||
[] for _ in range(len(kv_cache_group_ids))
|
[] for _ in range(len(kv_cache_group_ids))
|
||||||
)
|
)
|
||||||
@ -705,6 +715,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
|
|||||||
kv_cache_spec: KVCacheSpec,
|
kv_cache_spec: KVCacheSpec,
|
||||||
use_eagle: bool,
|
use_eagle: bool,
|
||||||
dcp_world_size: int = 1,
|
dcp_world_size: int = 1,
|
||||||
|
pcp_world_size: int = 1,
|
||||||
) -> tuple[list[KVCacheBlock], ...]:
|
) -> tuple[list[KVCacheBlock], ...]:
|
||||||
assert isinstance(kv_cache_spec, CrossAttentionSpec), (
|
assert isinstance(kv_cache_spec, CrossAttentionSpec), (
|
||||||
"CrossAttentionManager can only be used for cross-attention groups"
|
"CrossAttentionManager can only be used for cross-attention groups"
|
||||||
|
|||||||
@ -128,6 +128,7 @@ class EngineCore:
|
|||||||
scheduler_block_size = (
|
scheduler_block_size = (
|
||||||
vllm_config.cache_config.block_size
|
vllm_config.cache_config.block_size
|
||||||
* vllm_config.parallel_config.decode_context_parallel_size
|
* vllm_config.parallel_config.decode_context_parallel_size
|
||||||
|
* vllm_config.parallel_config.prefill_context_parallel_size
|
||||||
)
|
)
|
||||||
|
|
||||||
self.scheduler: SchedulerInterface = Scheduler(
|
self.scheduler: SchedulerInterface = Scheduler(
|
||||||
|
|||||||
@ -35,6 +35,7 @@ from vllm.distributed.parallel_state import (
|
|||||||
get_dp_group,
|
get_dp_group,
|
||||||
get_ep_group,
|
get_ep_group,
|
||||||
get_inner_dp_world_group,
|
get_inner_dp_world_group,
|
||||||
|
get_pcp_group,
|
||||||
get_pp_group,
|
get_pp_group,
|
||||||
get_tp_group,
|
get_tp_group,
|
||||||
)
|
)
|
||||||
@ -110,12 +111,14 @@ class MultiprocExecutor(Executor):
|
|||||||
f"({self.parallel_config.nnodes_within_dp}). "
|
f"({self.parallel_config.nnodes_within_dp}). "
|
||||||
)
|
)
|
||||||
self.local_world_size = self.parallel_config.local_world_size
|
self.local_world_size = self.parallel_config.local_world_size
|
||||||
tensor_parallel_size = self.parallel_config.tensor_parallel_size
|
tp_size = self.parallel_config.tensor_parallel_size
|
||||||
pp_parallel_size = self.parallel_config.pipeline_parallel_size
|
pp_size = self.parallel_config.pipeline_parallel_size
|
||||||
assert self.world_size == tensor_parallel_size * pp_parallel_size, (
|
pcp_size = self.parallel_config.prefill_context_parallel_size
|
||||||
|
assert self.world_size == tp_size * pp_size * pcp_size, (
|
||||||
f"world_size ({self.world_size}) must be equal to the "
|
f"world_size ({self.world_size}) must be equal to the "
|
||||||
f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
|
f"tensor_parallel_size ({tp_size}) x pipeline"
|
||||||
f"_parallel_size ({pp_parallel_size}). "
|
f"_parallel_size ({pp_size}) x prefill_context"
|
||||||
|
f"_parallel_size ({pcp_size}). "
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set multiprocessing envs
|
# Set multiprocessing envs
|
||||||
@ -424,7 +427,11 @@ class MultiprocExecutor(Executor):
|
|||||||
# 16-23, PP rank 2
|
# 16-23, PP rank 2
|
||||||
# 24-31, PP rank 3
|
# 24-31, PP rank 3
|
||||||
# so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3)
|
# so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3)
|
||||||
return self.world_size - self.parallel_config.tensor_parallel_size
|
return (
|
||||||
|
self.world_size
|
||||||
|
- self.parallel_config.tensor_parallel_size
|
||||||
|
* self.parallel_config.prefill_context_parallel_size
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -828,6 +835,8 @@ class WorkerProc:
|
|||||||
dp_rank = get_dp_group().rank_in_group
|
dp_rank = get_dp_group().rank_in_group
|
||||||
pp_size = get_pp_group().world_size
|
pp_size = get_pp_group().world_size
|
||||||
pp_rank = get_pp_group().rank_in_group
|
pp_rank = get_pp_group().rank_in_group
|
||||||
|
pcp_size = get_pcp_group().world_size
|
||||||
|
pcp_rank = get_pcp_group().rank_in_group
|
||||||
tp_size = get_tp_group().world_size
|
tp_size = get_tp_group().world_size
|
||||||
tp_rank = get_tp_group().rank_in_group
|
tp_rank = get_tp_group().rank_in_group
|
||||||
dcp_size = get_dcp_group().world_size
|
dcp_size = get_dcp_group().world_size
|
||||||
@ -837,6 +846,8 @@ class WorkerProc:
|
|||||||
process_name += f"_DP{dp_rank}"
|
process_name += f"_DP{dp_rank}"
|
||||||
if pp_size > 1:
|
if pp_size > 1:
|
||||||
process_name += f"_PP{pp_rank}"
|
process_name += f"_PP{pp_rank}"
|
||||||
|
if pcp_size > 1:
|
||||||
|
process_name += f"_PCP{pcp_rank}"
|
||||||
if tp_size > 1:
|
if tp_size > 1:
|
||||||
process_name += f"_TP{tp_rank}"
|
process_name += f"_TP{tp_rank}"
|
||||||
if dcp_size > 1:
|
if dcp_size > 1:
|
||||||
|
|||||||
@ -95,10 +95,11 @@ class FullAttentionSpec(AttentionSpec):
|
|||||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||||
max_model_len = vllm_config.model_config.max_model_len
|
max_model_len = vllm_config.model_config.max_model_len
|
||||||
dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
|
dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
|
||||||
|
pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size
|
||||||
# Note(hc): each dcp rank only need save
|
# Note(hc): each dcp rank only need save
|
||||||
# (max_model_len//dcp_world_size) tokens locally.
|
# (max_model_len//dcp_world_size) tokens locally.
|
||||||
if dcp_world_size > 1:
|
if dcp_world_size * pcp_world_size > 1:
|
||||||
max_model_len = cdiv(max_model_len, dcp_world_size)
|
max_model_len = cdiv(max_model_len, dcp_world_size * pcp_world_size)
|
||||||
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
|
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -4,7 +4,7 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.distributed import get_dcp_group
|
from vllm.distributed import get_dcp_group, get_pcp_group
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils.math_utils import cdiv
|
from vllm.utils.math_utils import cdiv
|
||||||
from vllm.v1.utils import CpuGpuBuffer
|
from vllm.v1.utils import CpuGpuBuffer
|
||||||
@ -22,7 +22,7 @@ class BlockTable:
|
|||||||
pin_memory: bool,
|
pin_memory: bool,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
kernel_block_size: int,
|
kernel_block_size: int,
|
||||||
dcp_kv_cache_interleave_size: int,
|
cp_kv_cache_interleave_size: int,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -80,6 +80,13 @@ class BlockTable:
|
|||||||
else:
|
else:
|
||||||
self._kernel_block_arange = None
|
self._kernel_block_arange = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.pcp_world_size = get_pcp_group().world_size
|
||||||
|
self.pcp_rank = get_pcp_group().rank_in_group
|
||||||
|
except AssertionError:
|
||||||
|
# DCP might not be initialized in testing
|
||||||
|
self.pcp_world_size = 1
|
||||||
|
self.pcp_rank = 0
|
||||||
try:
|
try:
|
||||||
self.dcp_world_size = get_dcp_group().world_size
|
self.dcp_world_size = get_dcp_group().world_size
|
||||||
self.dcp_rank = get_dcp_group().rank_in_group
|
self.dcp_rank = get_dcp_group().rank_in_group
|
||||||
@ -87,7 +94,7 @@ class BlockTable:
|
|||||||
# DCP might not be initialized in testing
|
# DCP might not be initialized in testing
|
||||||
self.dcp_world_size = 1
|
self.dcp_world_size = 1
|
||||||
self.dcp_rank = 0
|
self.dcp_rank = 0
|
||||||
self.dcp_kv_cache_interleave_size = dcp_kv_cache_interleave_size
|
self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
|
||||||
|
|
||||||
def append_row(
|
def append_row(
|
||||||
self,
|
self,
|
||||||
@ -131,14 +138,16 @@ class BlockTable:
|
|||||||
# NOTE(woosuk): We can't simply use `token_indices // block_size`
|
# NOTE(woosuk): We can't simply use `token_indices // block_size`
|
||||||
# here because M (max_model_len) is not necessarily divisible by
|
# here because M (max_model_len) is not necessarily divisible by
|
||||||
# block_size.
|
# block_size.
|
||||||
if self.dcp_world_size > 1:
|
total_cp_world_size = self.pcp_world_size * self.dcp_world_size
|
||||||
|
total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
|
||||||
|
if total_cp_world_size > 1:
|
||||||
# Note(hc): The DCP implement store kvcache with an interleave
|
# Note(hc): The DCP implement store kvcache with an interleave
|
||||||
# style, the kvcache for the token whose token_idx is i is
|
# style, the kvcache for the token whose token_idx is i is
|
||||||
# always stored on the GPU whose dcp_rank equals i % cp_world_size:
|
# always stored on the GPU whose dcp_rank equals i % cp_world_size:
|
||||||
|
|
||||||
# Use a "virtual block" which equals to world_size * block_size
|
# Use a "virtual block" which equals to world_size * block_size
|
||||||
# for block_table_indices calculation.
|
# for block_table_indices calculation.
|
||||||
virtual_block_size = self.block_size * self.dcp_world_size
|
virtual_block_size = self.block_size * total_cp_world_size
|
||||||
block_table_indices = (
|
block_table_indices = (
|
||||||
req_indices * self.max_num_blocks_per_req
|
req_indices * self.max_num_blocks_per_req
|
||||||
+ positions // virtual_block_size
|
+ positions // virtual_block_size
|
||||||
@ -150,16 +159,16 @@ class BlockTable:
|
|||||||
virtual_block_offsets = positions % virtual_block_size
|
virtual_block_offsets = positions % virtual_block_size
|
||||||
mask = (
|
mask = (
|
||||||
virtual_block_offsets
|
virtual_block_offsets
|
||||||
// self.dcp_kv_cache_interleave_size
|
// self.cp_kv_cache_interleave_size
|
||||||
% self.dcp_world_size
|
% total_cp_world_size
|
||||||
== self.dcp_rank
|
== total_cp_rank
|
||||||
)
|
)
|
||||||
# Calculate local block_offsets
|
# Calculate local block_offsets
|
||||||
block_offsets = (
|
block_offsets = (
|
||||||
virtual_block_offsets
|
virtual_block_offsets
|
||||||
// (self.dcp_world_size * self.dcp_kv_cache_interleave_size)
|
// (total_cp_world_size * self.cp_kv_cache_interleave_size)
|
||||||
* self.dcp_kv_cache_interleave_size
|
* self.cp_kv_cache_interleave_size
|
||||||
+ virtual_block_offsets % self.dcp_kv_cache_interleave_size
|
+ virtual_block_offsets % self.cp_kv_cache_interleave_size
|
||||||
)
|
)
|
||||||
# Calculate slot_mapping
|
# Calculate slot_mapping
|
||||||
slot_mapping = block_numbers * self.block_size + block_offsets
|
slot_mapping = block_numbers * self.block_size + block_offsets
|
||||||
@ -253,7 +262,7 @@ class MultiGroupBlockTable:
|
|||||||
block_sizes: list[int],
|
block_sizes: list[int],
|
||||||
kernel_block_sizes: list[int],
|
kernel_block_sizes: list[int],
|
||||||
num_speculative_tokens: int = 0,
|
num_speculative_tokens: int = 0,
|
||||||
dcp_kv_cache_interleave_size: int = 1,
|
cp_kv_cache_interleave_size: int = 1,
|
||||||
) -> None:
|
) -> None:
|
||||||
# Note(hc): each dcp rank only store
|
# Note(hc): each dcp rank only store
|
||||||
# (max_model_len//dcp_world_size) tokens in kvcache,
|
# (max_model_len//dcp_world_size) tokens in kvcache,
|
||||||
@ -283,7 +292,7 @@ class MultiGroupBlockTable:
|
|||||||
pin_memory,
|
pin_memory,
|
||||||
device,
|
device,
|
||||||
kernel_block_size,
|
kernel_block_size,
|
||||||
dcp_kv_cache_interleave_size,
|
cp_kv_cache_interleave_size,
|
||||||
)
|
)
|
||||||
for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
|
for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
|
||||||
]
|
]
|
||||||
|
|||||||
@ -87,7 +87,7 @@ class InputBatch:
|
|||||||
is_spec_decode: bool = False,
|
is_spec_decode: bool = False,
|
||||||
is_pooling_model: bool = False,
|
is_pooling_model: bool = False,
|
||||||
num_speculative_tokens: int = 0,
|
num_speculative_tokens: int = 0,
|
||||||
dcp_kv_cache_interleave_size: int = 1,
|
cp_kv_cache_interleave_size: int = 1,
|
||||||
):
|
):
|
||||||
self.is_pooling_model = is_pooling_model
|
self.is_pooling_model = is_pooling_model
|
||||||
self.is_spec_decode = is_spec_decode
|
self.is_spec_decode = is_spec_decode
|
||||||
@ -141,7 +141,7 @@ class InputBatch:
|
|||||||
block_sizes=block_sizes,
|
block_sizes=block_sizes,
|
||||||
kernel_block_sizes=kernel_block_sizes,
|
kernel_block_sizes=kernel_block_sizes,
|
||||||
num_speculative_tokens=num_speculative_tokens,
|
num_speculative_tokens=num_speculative_tokens,
|
||||||
dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
|
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Sampling-related.
|
# Sampling-related.
|
||||||
|
|||||||
@ -426,7 +426,7 @@ class GPUModelRunner(
|
|||||||
# uses output token ids so we set this conservatively.
|
# uses output token ids so we set this conservatively.
|
||||||
logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
|
logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
|
||||||
is_pooling_model=self.is_pooling_model,
|
is_pooling_model=self.is_pooling_model,
|
||||||
dcp_kv_cache_interleave_size=self.parallel_config.dcp_kv_cache_interleave_size,
|
cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.use_async_scheduling = self.scheduler_config.async_scheduling
|
self.use_async_scheduling = self.scheduler_config.async_scheduling
|
||||||
@ -1436,7 +1436,7 @@ class GPUModelRunner(
|
|||||||
self.seq_lens.cpu[:num_reqs],
|
self.seq_lens.cpu[:num_reqs],
|
||||||
self.dcp_world_size,
|
self.dcp_world_size,
|
||||||
self.dcp_rank,
|
self.dcp_rank,
|
||||||
self.parallel_config.dcp_kv_cache_interleave_size,
|
self.parallel_config.cp_kv_cache_interleave_size,
|
||||||
)
|
)
|
||||||
self.dcp_local_seq_lens.copy_to_gpu(num_reqs)
|
self.dcp_local_seq_lens.copy_to_gpu(num_reqs)
|
||||||
|
|
||||||
|
|||||||
@ -26,6 +26,7 @@ from vllm.distributed.kv_transfer import (
|
|||||||
has_kv_transfer_group,
|
has_kv_transfer_group,
|
||||||
)
|
)
|
||||||
from vllm.distributed.parallel_state import (
|
from vllm.distributed.parallel_state import (
|
||||||
|
get_pcp_group,
|
||||||
get_pp_group,
|
get_pp_group,
|
||||||
get_tp_group,
|
get_tp_group,
|
||||||
)
|
)
|
||||||
@ -733,6 +734,7 @@ class Worker(WorkerBase):
|
|||||||
module.global_num_experts = module.moe_config.num_experts
|
module.global_num_experts = module.moe_config.num_experts
|
||||||
module.moe_parallel_config = FusedMoEParallelConfig.make(
|
module.moe_parallel_config = FusedMoEParallelConfig.make(
|
||||||
tp_size_=get_tp_group().world_size,
|
tp_size_=get_tp_group().world_size,
|
||||||
|
pcp_size_=get_pcp_group().world_size,
|
||||||
dp_size_=get_dp_group().world_size,
|
dp_size_=get_dp_group().world_size,
|
||||||
vllm_parallel_config=parallel_config,
|
vllm_parallel_config=parallel_config,
|
||||||
)
|
)
|
||||||
@ -886,6 +888,7 @@ def init_worker_distributed_environment(
|
|||||||
ensure_model_parallel_initialized(
|
ensure_model_parallel_initialized(
|
||||||
parallel_config.tensor_parallel_size,
|
parallel_config.tensor_parallel_size,
|
||||||
parallel_config.pipeline_parallel_size,
|
parallel_config.pipeline_parallel_size,
|
||||||
|
parallel_config.prefill_context_parallel_size,
|
||||||
parallel_config.decode_context_parallel_size,
|
parallel_config.decode_context_parallel_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user