mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-24 00:14:34 +08:00
[Feature] Prefill Context Parallel (PCP) basic support (#28718)
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com> Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com> Signed-off-by: LookAround <lixushi@huawei.com> Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com> Signed-off-by: zhenwenqi2024 <zhenwenqi_2022@qq.com> Co-authored-by: FENP <yuanyongjie.yyj@antgroup.com> Co-authored-by: LookAround <lixushi@huawei.com> Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com> Co-authored-by: zhenwenqi2024 <zhenwenqi_2022@qq.com> Co-authored-by: Jingchun Gao <63247409+gjc0824@users.noreply.github.com>
This commit is contained in:
parent
02f5903b84
commit
2fd893b4ce
@ -31,7 +31,7 @@ class ParallelSetup(NamedTuple):
|
||||
tp_size: int
|
||||
pp_size: int
|
||||
dcp_size: int
|
||||
dcp_kv_cache_interleave_size: int
|
||||
cp_kv_cache_interleave_size: int
|
||||
eager_mode: bool
|
||||
chunked_prefill: bool
|
||||
|
||||
@ -55,7 +55,7 @@ class CPTestSettings:
|
||||
tp_base: int = 4,
|
||||
pp_base: int = 1,
|
||||
dcp_base: int = 1,
|
||||
dcp_kv_cache_interleave_size: int = 1,
|
||||
cp_kv_cache_interleave_size: int = 1,
|
||||
multi_node_only: bool = False,
|
||||
runner: RunnerOption = "auto",
|
||||
load_format: str | None = None,
|
||||
@ -71,7 +71,7 @@ class CPTestSettings:
|
||||
tp_size=tp_base,
|
||||
pp_size=pp_multiplier * pp_base,
|
||||
dcp_size=int(dcp_multiplier * tp_base),
|
||||
dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
|
||||
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
|
||||
eager_mode=eager_mode_val,
|
||||
chunked_prefill=chunked_prefill_val,
|
||||
)
|
||||
@ -116,7 +116,7 @@ def _compare_cp_with_tp(
|
||||
tp_size,
|
||||
pp_size,
|
||||
dcp_size,
|
||||
dcp_kv_cache_interleave_size,
|
||||
cp_kv_cache_interleave_size,
|
||||
eager_mode,
|
||||
chunked_prefill,
|
||||
) = parallel_setup
|
||||
@ -197,7 +197,7 @@ def _compare_cp_with_tp(
|
||||
"--decode-context-parallel-size",
|
||||
str(dcp_size),
|
||||
"--dcp-kv-cache-interleave-size",
|
||||
str(dcp_kv_cache_interleave_size),
|
||||
str(cp_kv_cache_interleave_size),
|
||||
"--distributed-executor-backend",
|
||||
distributed_backend,
|
||||
]
|
||||
@ -227,7 +227,7 @@ CP_TEXT_GENERATION_MODELS = {
|
||||
"deepseek-ai/DeepSeek-V2-Lite-Chat": [
|
||||
CPTestSettings.detailed(),
|
||||
CPTestSettings.detailed(tp_base=2),
|
||||
CPTestSettings.detailed(tp_base=2, dcp_kv_cache_interleave_size=64),
|
||||
CPTestSettings.detailed(tp_base=2, cp_kv_cache_interleave_size=64),
|
||||
],
|
||||
"bigcode/gpt_bigcode-santacoder": [
|
||||
CPTestSettings.detailed(),
|
||||
|
||||
@ -15,7 +15,11 @@ from tests.kernels.quantization.nvfp4_utils import (
|
||||
)
|
||||
from tests.kernels.utils import torch_experts
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.distributed import (
|
||||
get_dp_group,
|
||||
get_pcp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
@ -561,6 +565,7 @@ def make_modular_kernel(
|
||||
# make moe config
|
||||
moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
|
||||
tp_size_=get_tensor_model_parallel_world_size(),
|
||||
pcp_size_=get_pcp_group().world_size,
|
||||
dp_size_=get_dp_group().world_size,
|
||||
vllm_parallel_config=vllm_config.parallel_config,
|
||||
)
|
||||
|
||||
@ -956,7 +956,7 @@ def test_hybrid_block_table_initialization():
|
||||
max_num_reqs = 10
|
||||
max_num_blocks_per_req = 20
|
||||
max_num_batched_tokens = 512
|
||||
dcp_kv_cache_interleave_size = 8
|
||||
cp_kv_cache_interleave_size = 8
|
||||
|
||||
block_table = BlockTable(
|
||||
block_size=block_size,
|
||||
@ -966,7 +966,7 @@ def test_hybrid_block_table_initialization():
|
||||
pin_memory=False,
|
||||
device=torch.device(DEVICE),
|
||||
kernel_block_size=kernel_block_sizes[0],
|
||||
dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
|
||||
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
|
||||
)
|
||||
|
||||
# Verify hybrid block configuration
|
||||
|
||||
@ -266,6 +266,12 @@ class AttentionImpl(ABC, Generic[T]):
|
||||
dcp_world_size: int
|
||||
dcp_rank: int
|
||||
|
||||
pcp_world_size: int
|
||||
pcp_rank: int
|
||||
|
||||
total_cp_world_size: int
|
||||
total_cp_rank: int
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
# use __new__ so that all subclasses will call this
|
||||
self = super().__new__(cls)
|
||||
@ -278,6 +284,17 @@ class AttentionImpl(ABC, Generic[T]):
|
||||
# DCP might not be initialized in testing
|
||||
self.dcp_world_size = 1
|
||||
self.dcp_rank = 0
|
||||
try:
|
||||
from vllm.distributed.parallel_state import get_pcp_group
|
||||
|
||||
self.pcp_world_size = get_pcp_group().world_size
|
||||
self.pcp_rank = get_pcp_group().rank_in_group
|
||||
except AssertionError:
|
||||
self.pcp_world_size = 1
|
||||
self.pcp_rank = 0
|
||||
self.total_cp_world_size = self.pcp_world_size * self.dcp_world_size
|
||||
self.total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
|
||||
|
||||
self.need_to_return_lse_for_decode = (
|
||||
self.dcp_world_size > 1 and self.can_return_lse_for_decode
|
||||
)
|
||||
|
||||
@ -169,12 +169,11 @@ def correct_attn_out(
|
||||
return out, lse
|
||||
|
||||
|
||||
def cp_lse_ag_out_rs(
|
||||
def _cp_lse_common(
|
||||
cp_attn_out: torch.Tensor,
|
||||
cp_attn_lse: torch.Tensor,
|
||||
cp_group: GroupCoordinator,
|
||||
ctx: CPTritonContext = None,
|
||||
return_lse=False,
|
||||
ctx: CPTritonContext | None = None,
|
||||
):
|
||||
"""
|
||||
cp_attn_out: [ B, H, D ]
|
||||
@ -195,6 +194,22 @@ def cp_lse_ag_out_rs(
|
||||
cp_attn_lse = cp_attn_lse.contiguous()
|
||||
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
|
||||
out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
|
||||
assert out.is_contiguous()
|
||||
return out, lse
|
||||
|
||||
|
||||
def cp_lse_ag_out_rs(
|
||||
cp_attn_out: torch.Tensor,
|
||||
cp_attn_lse: torch.Tensor,
|
||||
cp_group: GroupCoordinator,
|
||||
ctx: CPTritonContext | None = None,
|
||||
return_lse: bool = False,
|
||||
):
|
||||
"""
|
||||
cp_attn_out: [ B, H, D ]
|
||||
cp_attn_lse: [ B, H ]
|
||||
"""
|
||||
out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
|
||||
out = cp_group.reduce_scatter(out, dim=1)
|
||||
|
||||
if return_lse:
|
||||
@ -205,6 +220,25 @@ def cp_lse_ag_out_rs(
|
||||
return out
|
||||
|
||||
|
||||
def cp_lse_ag_out_ar(
|
||||
cp_attn_out: torch.Tensor,
|
||||
cp_attn_lse: torch.Tensor,
|
||||
cp_group: GroupCoordinator,
|
||||
ctx: CPTritonContext | None = None,
|
||||
return_lse: bool = False,
|
||||
):
|
||||
"""
|
||||
cp_attn_out: [ B, H, D ]
|
||||
cp_attn_lse: [ B, H ]
|
||||
"""
|
||||
out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
|
||||
out = cp_group.all_reduce(out)
|
||||
|
||||
if return_lse:
|
||||
return out, lse
|
||||
return out
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _pack_seq_kernel(
|
||||
x_ptr, # [N, D]
|
||||
|
||||
@ -71,6 +71,8 @@ class ParallelConfig:
|
||||
"""Number of pipeline parallel groups."""
|
||||
tensor_parallel_size: int = 1
|
||||
"""Number of tensor parallel groups."""
|
||||
prefill_context_parallel_size: int = 1
|
||||
"""Number of prefill context parallel groups."""
|
||||
data_parallel_size: int = 1
|
||||
"""Number of data parallel groups. MoE layers will be sharded according to
|
||||
the product of the tensor parallel size and data parallel size."""
|
||||
@ -239,14 +241,25 @@ class ParallelConfig:
|
||||
needs to be divisible by dcp_size."""
|
||||
|
||||
dcp_kv_cache_interleave_size: int = 1
|
||||
"""Interleave size of kv_cache storage while using dcp or cp > 1,
|
||||
store interleave_size tokens on (d)cp i,
|
||||
then store next interleave_size tokens on (d)cp i+1.
|
||||
Interleave_size=1: token-level align, token i is stored on rank i % (d)cp_size.
|
||||
Interleave_size=block_size: block-level align, first fill the block on first rank,
|
||||
token is stored on rank i+1 block j after rank i block j is full.
|
||||
Block_size should be greater than or equal to dcp_kv_cache_interleave_size.
|
||||
Block_size should be divisible by dcp_kv_cache_interleave_size.
|
||||
"""
|
||||
Interleave size of kv_cache storage while using DCP.
|
||||
dcp_kv_cache_interleave_size has been replaced by cp_kv_cache_interleave_size,
|
||||
and will be deprecated when PCP is fully supported.
|
||||
|
||||
"""
|
||||
cp_kv_cache_interleave_size: int = 1
|
||||
"""Interleave size of kv_cache storage while using DCP or PCP.
|
||||
For `total_cp_rank = pcp_rank * dcp_world_size + dcp_rank`,
|
||||
and `total_cp_world_size = pcp_world_size * dcp_world_szie`.
|
||||
store interleave_size tokens on total_cp_rank i,
|
||||
then store next interleave_size tokens on taotal_cp_rank i+1.
|
||||
Interleave_size=1: token-level alignment, where token `i` is stored on
|
||||
total_cp_rank `i % total_cp_world_size`.
|
||||
Interleave_size=block_size: block-level alignment, where tokens are
|
||||
first populated to the preceding ranks. Tokens are then stored
|
||||
in (rank i+1, block j) only after (rank i, block j) is fully occupied.
|
||||
Block_size should be greater than or equal to cp_kv_cache_interleave_size.
|
||||
Block_size should be divisible by cp_kv_cache_interleave_size.
|
||||
"""
|
||||
|
||||
_api_process_count: int = Field(default=1, gt=0)
|
||||
@ -311,6 +324,11 @@ class ParallelConfig:
|
||||
"num_redundant_experts."
|
||||
)
|
||||
|
||||
if self.prefill_context_parallel_size > 1:
|
||||
raise ValueError(
|
||||
"Prefill context parallelism is not fully supported. "
|
||||
"Please set prefill_context_parallel_size to 1."
|
||||
)
|
||||
return self
|
||||
|
||||
@property
|
||||
@ -529,7 +547,11 @@ class ParallelConfig:
|
||||
)
|
||||
|
||||
# Continue with the rest of the initialization
|
||||
self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size
|
||||
self.world_size = (
|
||||
self.pipeline_parallel_size
|
||||
* self.tensor_parallel_size
|
||||
* self.prefill_context_parallel_size
|
||||
)
|
||||
|
||||
if self.distributed_executor_backend == "external_launcher":
|
||||
logger.info("Using external launcher for distributed inference.")
|
||||
|
||||
@ -481,6 +481,14 @@ class VllmConfig:
|
||||
"Overriding cudagraph_mode to PIECEWISE."
|
||||
)
|
||||
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
|
||||
# prefill context parallel do not support full cudagraphs
|
||||
elif self.parallel_config.prefill_context_parallel_size > 1:
|
||||
logger.warning_once(
|
||||
"Prefill context parallel (PCP) is enabled, which is "
|
||||
"incompatible with full CUDA graphs. "
|
||||
"Overriding cudagraph_mode to PIECEWISE."
|
||||
)
|
||||
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
|
||||
elif self.model_config is not None:
|
||||
if self.model_config.pooler_config is not None:
|
||||
logger.warning_once(
|
||||
@ -610,22 +618,34 @@ class VllmConfig:
|
||||
|
||||
# If DCP, ensure the block size is right.
|
||||
if self.parallel_config.decode_context_parallel_size > 1:
|
||||
if self.parallel_config.dcp_kv_cache_interleave_size > 1 and (
|
||||
self.parallel_config.cp_kv_cache_interleave_size
|
||||
!= self.parallel_config.dcp_kv_cache_interleave_size
|
||||
):
|
||||
self.parallel_config.cp_kv_cache_interleave_size = (
|
||||
self.parallel_config.dcp_kv_cache_interleave_size
|
||||
)
|
||||
logger.warning_once(
|
||||
"cp_kv_cache_interleave_size is overridden by dcp_kv_cache"
|
||||
"_interleave_size. And dcp-kv-cache-interleave-size will be "
|
||||
"deprecated when PCP is fully supported."
|
||||
)
|
||||
assert (
|
||||
self.parallel_config.dcp_kv_cache_interleave_size
|
||||
self.parallel_config.cp_kv_cache_interleave_size
|
||||
<= self.cache_config.block_size
|
||||
and self.cache_config.block_size
|
||||
% self.parallel_config.dcp_kv_cache_interleave_size
|
||||
% self.parallel_config.cp_kv_cache_interleave_size
|
||||
== 0
|
||||
), (
|
||||
f"Block_size({self.cache_config.block_size}) should be greater "
|
||||
"than or equal to and divisible by dcp_kv_cache_interleave_size "
|
||||
f"({self.parallel_config.dcp_kv_cache_interleave_size})."
|
||||
"than or equal to and divisible by cp_kv_cache_interleave_size "
|
||||
f"({self.parallel_config.cp_kv_cache_interleave_size})."
|
||||
)
|
||||
|
||||
assert (
|
||||
self.parallel_config.dcp_kv_cache_interleave_size == 1
|
||||
self.parallel_config.cp_kv_cache_interleave_size == 1
|
||||
or self.speculative_config is None
|
||||
), "MTP with dcp_kv_cache_interleave_size > 1 is not supported now."
|
||||
), "MTP with cp_kv_cache_interleave_size > 1 is not supported now."
|
||||
|
||||
# Do this after all the updates to compilation_config.mode
|
||||
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
|
||||
|
||||
@ -1098,6 +1098,12 @@ get_context_model_parallel_group = get_dcp_group
|
||||
|
||||
_PP: GroupCoordinator | None = None
|
||||
|
||||
|
||||
def get_pp_group() -> GroupCoordinator:
|
||||
assert _PP is not None, "pipeline model parallel group is not initialized"
|
||||
return _PP
|
||||
|
||||
|
||||
_DP: GroupCoordinator | None = None
|
||||
|
||||
|
||||
@ -1114,9 +1120,12 @@ def get_ep_group() -> GroupCoordinator:
|
||||
return _EP
|
||||
|
||||
|
||||
def get_pp_group() -> GroupCoordinator:
|
||||
assert _PP is not None, "pipeline model parallel group is not initialized"
|
||||
return _PP
|
||||
_PCP: GroupCoordinator | None = None
|
||||
|
||||
|
||||
def get_pcp_group() -> GroupCoordinator:
|
||||
assert _PCP is not None, "prefill context parallel group is not initialized"
|
||||
return _PCP
|
||||
|
||||
|
||||
@deprecated(
|
||||
@ -1276,6 +1285,7 @@ def init_distributed_environment(
|
||||
def initialize_model_parallel(
|
||||
tensor_model_parallel_size: int = 1,
|
||||
pipeline_model_parallel_size: int = 1,
|
||||
prefill_context_model_parallel_size: int = 1,
|
||||
decode_context_model_parallel_size: int | None = 1,
|
||||
backend: str | None = None,
|
||||
) -> None:
|
||||
@ -1325,7 +1335,11 @@ def initialize_model_parallel(
|
||||
# to get group_ranks for each dimension, transpose that dimension to the
|
||||
# last dimension, then reshape to 2D, then unbind the last dimension
|
||||
all_ranks = torch.arange(world_size).reshape(
|
||||
-1, data_parallel_size, pipeline_model_parallel_size, tensor_model_parallel_size
|
||||
-1,
|
||||
data_parallel_size,
|
||||
pipeline_model_parallel_size,
|
||||
prefill_context_model_parallel_size,
|
||||
tensor_model_parallel_size,
|
||||
) # noqa
|
||||
|
||||
# Build the tensor model-parallel groups.
|
||||
@ -1360,11 +1374,23 @@ def initialize_model_parallel(
|
||||
group_name="dcp",
|
||||
)
|
||||
|
||||
global _PCP
|
||||
assert _PCP is None, "prefill context parallel group is already initialized"
|
||||
group_ranks = (
|
||||
all_ranks.transpose(3, 4)
|
||||
.reshape(-1, prefill_context_model_parallel_size)
|
||||
.unbind(0)
|
||||
)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
_PCP = init_model_parallel_group(
|
||||
group_ranks, get_world_group().local_rank, backend, group_name="pcp"
|
||||
)
|
||||
|
||||
# Build the pipeline model-parallel groups.
|
||||
global _PP
|
||||
assert _PP is None, "pipeline model parallel group is already initialized"
|
||||
group_ranks = (
|
||||
all_ranks.transpose(2, 3).reshape(-1, pipeline_model_parallel_size).unbind(0)
|
||||
all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0)
|
||||
)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
_PP = init_model_parallel_group(
|
||||
@ -1373,7 +1399,7 @@ def initialize_model_parallel(
|
||||
|
||||
global _DP
|
||||
assert _DP is None, "data parallel group is already initialized"
|
||||
group_ranks = all_ranks.transpose(1, 3).reshape(-1, data_parallel_size).unbind(0)
|
||||
group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
_DP = init_model_parallel_group(
|
||||
group_ranks, get_world_group().local_rank, backend, group_name="dp"
|
||||
@ -1383,7 +1409,12 @@ def initialize_model_parallel(
|
||||
assert _EP is None, "expert parallel group is already initialized"
|
||||
group_ranks = (
|
||||
all_ranks.transpose(1, 2)
|
||||
.reshape(-1, data_parallel_size * tensor_model_parallel_size)
|
||||
.reshape(
|
||||
-1,
|
||||
data_parallel_size
|
||||
* prefill_context_model_parallel_size
|
||||
* tensor_model_parallel_size,
|
||||
)
|
||||
.unbind(0)
|
||||
)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
@ -1393,11 +1424,13 @@ def initialize_model_parallel(
|
||||
|
||||
logger.info_once(
|
||||
"rank %s in world size %s is assigned as "
|
||||
"DP rank %s, PP rank %s, TP rank %s, EP rank %s",
|
||||
"DP rank %s, PP rank %s, PCP rank %s, "
|
||||
"TP rank %s, EP rank %s",
|
||||
rank,
|
||||
world_size,
|
||||
_DP.rank_in_group,
|
||||
_PP.rank_in_group,
|
||||
_PCP.rank_in_group,
|
||||
_TP.rank_in_group,
|
||||
_EP.rank_in_group,
|
||||
)
|
||||
@ -1406,6 +1439,7 @@ def initialize_model_parallel(
|
||||
def ensure_model_parallel_initialized(
|
||||
tensor_model_parallel_size: int,
|
||||
pipeline_model_parallel_size: int,
|
||||
prefill_context_model_parallel_size: int = 1,
|
||||
decode_context_model_parallel_size: int | None = 1,
|
||||
backend: str | None = None,
|
||||
) -> None:
|
||||
@ -1418,6 +1452,7 @@ def ensure_model_parallel_initialized(
|
||||
initialize_model_parallel(
|
||||
tensor_model_parallel_size,
|
||||
pipeline_model_parallel_size,
|
||||
prefill_context_model_parallel_size,
|
||||
decode_context_model_parallel_size,
|
||||
backend,
|
||||
)
|
||||
@ -1434,6 +1469,12 @@ def ensure_model_parallel_initialized(
|
||||
f"got: {pp_world_size=} vs. "
|
||||
f"wanted: {pipeline_model_parallel_size=}"
|
||||
)
|
||||
pcp_world_size = get_pcp_group().world_size
|
||||
assert pcp_world_size == prefill_context_model_parallel_size, (
|
||||
"prefill context parallel group already initialized, but of unexpected size: "
|
||||
f"{pcp_world_size=} vs. "
|
||||
f"{prefill_context_model_parallel_size=}"
|
||||
)
|
||||
|
||||
|
||||
def prepare_communication_buffer_for_model(model: torch.nn.Module):
|
||||
@ -1445,6 +1486,8 @@ def prepare_communication_buffer_for_model(model: torch.nn.Module):
|
||||
"""
|
||||
if _TP is not None:
|
||||
_TP.prepare_communication_buffer_for_model(model)
|
||||
if _PCP is not None:
|
||||
_PCP.prepare_communication_buffer_for_model(model)
|
||||
if _PP is not None:
|
||||
_PP.prepare_communication_buffer_for_model(model)
|
||||
if _DP is not None:
|
||||
@ -1520,16 +1563,21 @@ def destroy_model_parallel():
|
||||
_TP.destroy()
|
||||
_TP = None
|
||||
|
||||
global _PP
|
||||
if _PP:
|
||||
_PP.destroy()
|
||||
_PP = None
|
||||
|
||||
global _DCP
|
||||
if _DCP:
|
||||
_DCP.destroy()
|
||||
_DCP = None
|
||||
|
||||
global _PCP
|
||||
if _PCP:
|
||||
_PCP.destroy()
|
||||
_PCP = None
|
||||
|
||||
global _PP
|
||||
if _PP:
|
||||
_PP.destroy()
|
||||
_PP = None
|
||||
|
||||
global _DP
|
||||
if _DP:
|
||||
_DP.destroy()
|
||||
|
||||
@ -389,8 +389,10 @@ class EngineArgs:
|
||||
nnodes: int = ParallelConfig.nnodes
|
||||
node_rank: int = ParallelConfig.node_rank
|
||||
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
|
||||
prefill_context_parallel_size: int = ParallelConfig.prefill_context_parallel_size
|
||||
decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
|
||||
dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
|
||||
cp_kv_cache_interleave_size: int = ParallelConfig.cp_kv_cache_interleave_size
|
||||
data_parallel_size: int = ParallelConfig.data_parallel_size
|
||||
data_parallel_rank: int | None = None
|
||||
data_parallel_start_rank: int | None = None
|
||||
@ -770,6 +772,15 @@ class EngineArgs:
|
||||
"--dcp-kv-cache-interleave-size",
|
||||
**parallel_kwargs["dcp_kv_cache_interleave_size"],
|
||||
)
|
||||
parallel_group.add_argument(
|
||||
"--cp-kv-cache-interleave-size",
|
||||
**parallel_kwargs["cp_kv_cache_interleave_size"],
|
||||
)
|
||||
parallel_group.add_argument(
|
||||
"--prefill-context-parallel-size",
|
||||
"-pcp",
|
||||
**parallel_kwargs["prefill_context_parallel_size"],
|
||||
)
|
||||
parallel_group.add_argument(
|
||||
"--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
|
||||
)
|
||||
@ -1600,6 +1611,7 @@ class EngineArgs:
|
||||
parallel_config = ParallelConfig(
|
||||
pipeline_parallel_size=self.pipeline_parallel_size,
|
||||
tensor_parallel_size=self.tensor_parallel_size,
|
||||
prefill_context_parallel_size=self.prefill_context_parallel_size,
|
||||
data_parallel_size=self.data_parallel_size,
|
||||
data_parallel_rank=self.data_parallel_rank or 0,
|
||||
data_parallel_external_lb=data_parallel_external_lb,
|
||||
@ -1631,6 +1643,7 @@ class EngineArgs:
|
||||
worker_extension_cls=self.worker_extension_cls,
|
||||
decode_context_parallel_size=self.decode_context_parallel_size,
|
||||
dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size,
|
||||
cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size,
|
||||
_api_process_count=self._api_process_count,
|
||||
_api_process_rank=self._api_process_rank,
|
||||
)
|
||||
@ -1952,6 +1965,15 @@ class EngineArgs:
|
||||
default_prefix_caching,
|
||||
) = self.get_chunked_prefill_prefix_caching_defaults(model_config)
|
||||
|
||||
if self.prefill_context_parallel_size > 1:
|
||||
default_chunked_prefill = False
|
||||
default_prefix_caching = False
|
||||
logger.warning(
|
||||
"--prefill-context-parallel-size > 1 is not compatible with "
|
||||
"chunked prefill and prefix caching now. Chunked prefill "
|
||||
"and prefix caching have been disabled by default."
|
||||
)
|
||||
|
||||
if self.enable_chunked_prefill is None:
|
||||
self.enable_chunked_prefill = default_chunked_prefill
|
||||
|
||||
|
||||
@ -8,7 +8,11 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank
|
||||
from vllm.distributed import (
|
||||
get_dp_group,
|
||||
get_pcp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
|
||||
OCP_MX_DTYPES,
|
||||
@ -684,9 +688,11 @@ FUSED_MOE_UNQUANTIZED_CONFIG: FusedMoEQuantConfig = FusedMoEQuantConfig.make()
|
||||
@dataclass
|
||||
class FusedMoEParallelConfig:
|
||||
tp_size: int
|
||||
pcp_size: int
|
||||
dp_size: int
|
||||
ep_size: int
|
||||
tp_rank: int
|
||||
pcp_rank: int
|
||||
dp_rank: int
|
||||
ep_rank: int
|
||||
|
||||
@ -713,19 +719,22 @@ class FusedMoEParallelConfig:
|
||||
return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"
|
||||
|
||||
@staticmethod
|
||||
def flatten_tp_across_dp(
|
||||
tp_size: int, dp_size: int, dp_rank: int
|
||||
def flatten_tp_across_dp_and_pcp(
|
||||
tp_size: int, dp_size: int, dp_rank: int, pcp_size: int, pcp_rank: int
|
||||
) -> tuple[int, int]:
|
||||
tp_rank = 0 if tp_size == 1 else get_tensor_model_parallel_rank()
|
||||
# There are actually dp_size * tp_size devices. Update tp_size
|
||||
# and tp_rank so we shard across all devices.
|
||||
flatten_tp_size = dp_size * tp_size
|
||||
flatten_tp_rank = dp_rank * tp_size + tp_rank
|
||||
# There are actually dp_size * pcp_size * tp_size devices.
|
||||
# Update tp_size and tp_rank so we shard across all devices.
|
||||
flatten_tp_size = dp_size * pcp_size * tp_size
|
||||
flatten_tp_rank = dp_rank * pcp_size * tp_size + pcp_rank * tp_size + tp_rank
|
||||
return flatten_tp_size, flatten_tp_rank
|
||||
|
||||
@staticmethod
|
||||
def make(
|
||||
tp_size_: int, dp_size_: int, vllm_parallel_config: ParallelConfig
|
||||
tp_size_: int,
|
||||
pcp_size_: int,
|
||||
dp_size_: int,
|
||||
vllm_parallel_config: ParallelConfig,
|
||||
) -> "FusedMoEParallelConfig":
|
||||
"""
|
||||
Determine MoE parallel configuration. Based on the input `tp_size_`,
|
||||
@ -734,19 +743,22 @@ class FusedMoEParallelConfig:
|
||||
|
||||
Args:
|
||||
tp_size_ (int): `tp_size` passed into the FusedMoE constructor.
|
||||
pcp_size_ (int): `pcp_size` passed into the FusedMoE constructor.
|
||||
dp_size_ (int): `dp_size` passed into the FusedMoE constructor.
|
||||
vllm_parallel_config (ParallelConfig): vLLM's parallel config
|
||||
object which contains the `enable_expert_parallel` flag.
|
||||
|
||||
Examples:
|
||||
When there is no parallelism requested,
|
||||
i.e. `tp_size_` = `dp_size_` = 1, we simply return the sizes
|
||||
i.e. `tp_size_` = `pcp_size_` = `dp_size_` = 1, we simply return the sizes
|
||||
unaltered and the ranks set to 0.
|
||||
|
||||
Expert Parallelism is considered only when either `dp_size_` or
|
||||
Expert Parallelism is considered only when either `dp_size_`, `pcp_size_` or
|
||||
`tp_size_` is non trivial.
|
||||
|
||||
When TP = 2, DP = 1 and EP = False, the configuration on different
|
||||
Note that PCP serves the same function as DP here.
|
||||
|
||||
When TP = 2, DP(PCP) = 1 and EP = False, the configuration on different
|
||||
devices:
|
||||
|
||||
- device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} //
|
||||
@ -754,7 +766,7 @@ class FusedMoEParallelConfig:
|
||||
- device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0}
|
||||
- Comment : Tensors are sharded across 2 devices.
|
||||
|
||||
When TP = 1, DP = 2 and EP = False, the configuration on different
|
||||
When TP = 1, DP(PCP) = 2 and EP = False, the configuration on different
|
||||
devices:
|
||||
|
||||
- device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0}
|
||||
@ -762,7 +774,7 @@ class FusedMoEParallelConfig:
|
||||
- Comment: There are 2 engine instances and the tensors are sharded
|
||||
across 2 decvices.
|
||||
|
||||
When TP = 2, DP = 2 and EP = False, the configuration on different
|
||||
When TP = 2, DP(PCP) = 2 and EP = False, the configuration on different
|
||||
devices:
|
||||
|
||||
- device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0}
|
||||
@ -772,14 +784,14 @@ class FusedMoEParallelConfig:
|
||||
- Comment: There are 2 engine instances and the tensors are sharded
|
||||
across 4 devices.
|
||||
|
||||
When, TP = 2, DP = 1 and EP = True, the configuration on different
|
||||
When, TP = 2, DP(PCP) = 1 and EP = True, the configuration on different
|
||||
devices:
|
||||
|
||||
- device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0}
|
||||
- device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1}
|
||||
- Comment: The experts are split between the 2 devices.
|
||||
|
||||
When, TP = 1, DP = 2 and EP = True, the configuration on different
|
||||
When, TP = 1, DP(PCP) = 2 and EP = True, the configuration on different
|
||||
devices:
|
||||
|
||||
- device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0}
|
||||
@ -787,7 +799,7 @@ class FusedMoEParallelConfig:
|
||||
- Comment: There are 2 engine instances and the experts are split
|
||||
between the 2 devices.
|
||||
|
||||
When TP = 2, DP = 2 and EP = True, the configuration on different
|
||||
When TP = 2, DP(PCP) = 2 and EP = True, the configuration on different
|
||||
devices:
|
||||
|
||||
- device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0}
|
||||
@ -798,18 +810,25 @@ class FusedMoEParallelConfig:
|
||||
between the 4 devices.
|
||||
"""
|
||||
|
||||
use_ep = dp_size_ * tp_size_ > 1 and vllm_parallel_config.enable_expert_parallel
|
||||
use_ep = (
|
||||
dp_size_ * pcp_size_ * tp_size_ > 1
|
||||
and vllm_parallel_config.enable_expert_parallel
|
||||
)
|
||||
|
||||
dp_size = dp_size_
|
||||
dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
|
||||
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
|
||||
tp_size_, dp_size_, dp_rank
|
||||
pcp_size = pcp_size_
|
||||
pcp_rank = get_pcp_group().rank_in_group if pcp_size > 1 else 0
|
||||
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
|
||||
tp_size_, dp_size_, dp_rank, pcp_size_, pcp_rank
|
||||
)
|
||||
|
||||
if not use_ep:
|
||||
return FusedMoEParallelConfig(
|
||||
tp_size=tp_size,
|
||||
tp_rank=tp_rank,
|
||||
pcp_size=pcp_size,
|
||||
pcp_rank=pcp_rank,
|
||||
dp_size=dp_size,
|
||||
dp_rank=dp_rank,
|
||||
ep_size=1,
|
||||
@ -826,6 +845,8 @@ class FusedMoEParallelConfig:
|
||||
return FusedMoEParallelConfig(
|
||||
tp_size=1,
|
||||
tp_rank=0,
|
||||
pcp_size=pcp_size,
|
||||
pcp_rank=pcp_rank,
|
||||
dp_size=dp_size,
|
||||
dp_rank=dp_rank,
|
||||
ep_size=ep_size,
|
||||
|
||||
@ -18,6 +18,7 @@ from vllm.config.parallel import ExpertPlacementStrategy
|
||||
from vllm.distributed import (
|
||||
get_dp_group,
|
||||
get_ep_group,
|
||||
get_pcp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
@ -343,6 +344,7 @@ class FusedMoE(CustomOp):
|
||||
tp_size: int | None = None,
|
||||
ep_size: int | None = None,
|
||||
dp_size: int | None = None,
|
||||
pcp_size: int | None = None,
|
||||
prefix: str = "",
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
@ -398,12 +400,14 @@ class FusedMoE(CustomOp):
|
||||
tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
|
||||
)
|
||||
dp_size_ = dp_size if dp_size is not None else get_dp_group().world_size
|
||||
pcp_size_ = pcp_size if pcp_size is not None else get_pcp_group().world_size
|
||||
|
||||
self.is_sequence_parallel = is_sequence_parallel
|
||||
self.sp_size = tp_size_ if is_sequence_parallel else 1
|
||||
|
||||
self.moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
|
||||
tp_size_=tp_size_,
|
||||
pcp_size_=pcp_size_,
|
||||
dp_size_=dp_size_,
|
||||
vllm_parallel_config=vllm_config.parallel_config,
|
||||
)
|
||||
@ -679,6 +683,10 @@ class FusedMoE(CustomOp):
|
||||
def dp_size(self):
|
||||
return self.moe_parallel_config.dp_size
|
||||
|
||||
@property
|
||||
def pcp_size(self):
|
||||
return self.moe_parallel_config.pcp_size
|
||||
|
||||
@property
|
||||
def ep_size(self):
|
||||
return self.moe_parallel_config.ep_size
|
||||
@ -691,6 +699,10 @@ class FusedMoE(CustomOp):
|
||||
def dp_rank(self):
|
||||
return self.moe_parallel_config.dp_rank
|
||||
|
||||
@property
|
||||
def pcp_rank(self):
|
||||
return self.moe_parallel_config.pcp_rank
|
||||
|
||||
@property
|
||||
def ep_rank(self):
|
||||
return self.moe_parallel_config.ep_rank
|
||||
@ -1871,6 +1883,19 @@ class FusedMoE(CustomOp):
|
||||
assert self.shared_experts is not None
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
|
||||
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
|
||||
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
|
||||
# we should modify All2AllManager abstract to better support PCP.
|
||||
if self.pcp_size > 1:
|
||||
hidden_states = get_pcp_group().all_gather(
|
||||
hidden_states,
|
||||
dim=0,
|
||||
)
|
||||
router_logits = get_pcp_group().all_gather(
|
||||
router_logits,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# Matrix multiply.
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
@ -1925,6 +1950,13 @@ class FusedMoE(CustomOp):
|
||||
def combine_output(states: torch.Tensor) -> torch.Tensor:
|
||||
if do_naive_dispatch_combine:
|
||||
states = get_ep_group().combine(states, self.is_sequence_parallel)
|
||||
|
||||
if self.pcp_size > 1:
|
||||
states = get_pcp_group().reduce_scatter(
|
||||
states,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
return states
|
||||
|
||||
if self.shared_experts is not None:
|
||||
|
||||
@ -13,6 +13,7 @@ from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (
|
||||
get_dp_group,
|
||||
get_ep_group,
|
||||
get_pcp_group,
|
||||
get_pp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
@ -322,10 +323,12 @@ class GptOssModel(nn.Module):
|
||||
|
||||
# In MoE, we need to flatten the tensor parallel size across the data
|
||||
# parallel size when EP is disabled.
|
||||
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
|
||||
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
|
||||
tp_size=get_tensor_model_parallel_world_size(),
|
||||
dp_size=get_dp_group().world_size,
|
||||
dp_rank=get_dp_group().rank_in_group,
|
||||
pcp_size=get_pcp_group().world_size,
|
||||
pcp_rank=get_pcp_group().rank_in_group,
|
||||
)
|
||||
|
||||
intermediate_size = self.config.intermediate_size
|
||||
@ -507,10 +510,12 @@ class GptOssModel(nn.Module):
|
||||
|
||||
# In MoE, we need to flatten the tensor parallel size across the data
|
||||
# parallel size when EP is disabled.
|
||||
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
|
||||
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
|
||||
tp_size=get_tensor_model_parallel_world_size(),
|
||||
dp_size=get_dp_group().world_size,
|
||||
dp_rank=get_dp_group().rank_in_group,
|
||||
pcp_size=get_pcp_group().world_size,
|
||||
pcp_rank=get_pcp_group().rank_in_group,
|
||||
)
|
||||
|
||||
intermediate_size = self.config.intermediate_size
|
||||
|
||||
@ -265,8 +265,8 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
||||
self.dcp_world_size = 1
|
||||
self.dcp_rank = 0
|
||||
|
||||
self.dcp_kv_cache_interleave_size = (
|
||||
self.parallel_config.dcp_kv_cache_interleave_size
|
||||
self.cp_kv_cache_interleave_size = (
|
||||
self.parallel_config.cp_kv_cache_interleave_size
|
||||
)
|
||||
|
||||
self.use_full_cuda_graph = (
|
||||
@ -388,7 +388,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
||||
dcp_context_kv_lens_cpu,
|
||||
self.dcp_world_size,
|
||||
self.dcp_rank,
|
||||
self.dcp_kv_cache_interleave_size,
|
||||
self.cp_kv_cache_interleave_size,
|
||||
)
|
||||
dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
|
||||
max_dcp_context_kv_len = dcp_context_kv_lens.max().item()
|
||||
|
||||
@ -536,7 +536,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
# DCP might not be initialized in testing
|
||||
self.dcp_world_size = 1
|
||||
self.dcp_rank = 0
|
||||
self.dcp_local_block_size = parallel_config.dcp_kv_cache_interleave_size
|
||||
self.dcp_local_block_size = parallel_config.cp_kv_cache_interleave_size
|
||||
self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size
|
||||
|
||||
# Don't try to access the runner on AMD
|
||||
@ -1289,8 +1289,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
get_current_vllm_config()
|
||||
)
|
||||
)
|
||||
self.dcp_kv_cache_interleave_size: int = (
|
||||
get_current_vllm_config().parallel_config.dcp_kv_cache_interleave_size
|
||||
self.cp_kv_cache_interleave_size: int = (
|
||||
get_current_vllm_config().parallel_config.cp_kv_cache_interleave_size
|
||||
)
|
||||
|
||||
def _flash_attn_varlen_diff_headdims(
|
||||
|
||||
@ -1080,9 +1080,9 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
|
||||
|
||||
def get_dcp_local_seq_lens(
|
||||
seq_lens: torch.Tensor,
|
||||
dcp_world_size: int = 1,
|
||||
dcp_size: int = 1,
|
||||
dcp_rank: int | None = None,
|
||||
dcp_kv_cache_interleave_size: int = 1,
|
||||
cp_kv_cache_interleave_size: int = 1,
|
||||
) -> torch.Tensor:
|
||||
"""While using dcp, kv_cache size stored on each rank may be different,
|
||||
use this function to calculate split decode seq_lens of each dcp rank.
|
||||
@ -1091,7 +1091,7 @@ def get_dcp_local_seq_lens(
|
||||
num_requests = seq_lens.size(0)
|
||||
if dcp_rank is None:
|
||||
rank_offsets = (
|
||||
torch.arange(dcp_world_size, dtype=torch.int32)
|
||||
torch.arange(dcp_size, dtype=torch.int32)
|
||||
.unsqueeze(0)
|
||||
.repeat(num_requests, 1)
|
||||
)
|
||||
@ -1102,15 +1102,15 @@ def get_dcp_local_seq_lens(
|
||||
)
|
||||
base = (
|
||||
seq_lens_tiled
|
||||
// dcp_kv_cache_interleave_size
|
||||
// dcp_world_size
|
||||
* dcp_kv_cache_interleave_size
|
||||
// cp_kv_cache_interleave_size
|
||||
// dcp_size
|
||||
* cp_kv_cache_interleave_size
|
||||
)
|
||||
remainder = seq_lens_tiled - base * dcp_world_size
|
||||
remainder = seq_lens_tiled - base * dcp_size
|
||||
remainder = torch.clip(
|
||||
remainder - rank_offsets * dcp_kv_cache_interleave_size,
|
||||
remainder - rank_offsets * cp_kv_cache_interleave_size,
|
||||
0,
|
||||
dcp_kv_cache_interleave_size,
|
||||
cp_kv_cache_interleave_size,
|
||||
)
|
||||
dcp_local_seq_lens = base + remainder
|
||||
return dcp_local_seq_lens.squeeze(1)
|
||||
|
||||
@ -27,6 +27,7 @@ class KVCacheCoordinator(ABC):
|
||||
enable_caching: bool,
|
||||
enable_kv_cache_events: bool,
|
||||
dcp_world_size: int,
|
||||
pcp_world_size: int,
|
||||
):
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.max_model_len = max_model_len
|
||||
@ -44,6 +45,7 @@ class KVCacheCoordinator(ABC):
|
||||
block_pool=self.block_pool,
|
||||
kv_cache_group_id=i,
|
||||
dcp_world_size=dcp_world_size,
|
||||
pcp_world_size=pcp_world_size,
|
||||
)
|
||||
for i, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups)
|
||||
)
|
||||
@ -210,6 +212,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
|
||||
use_eagle: bool,
|
||||
enable_kv_cache_events: bool,
|
||||
dcp_world_size: int,
|
||||
pcp_world_size: int,
|
||||
):
|
||||
super().__init__(
|
||||
kv_cache_config,
|
||||
@ -218,6 +221,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
|
||||
False,
|
||||
enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size,
|
||||
pcp_world_size=pcp_world_size,
|
||||
)
|
||||
self.num_single_type_manager = len(self.single_type_managers)
|
||||
|
||||
@ -250,6 +254,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
|
||||
enable_caching: bool,
|
||||
enable_kv_cache_events: bool,
|
||||
dcp_world_size: int,
|
||||
pcp_world_size: int,
|
||||
):
|
||||
super().__init__(
|
||||
kv_cache_config,
|
||||
@ -258,12 +263,16 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
|
||||
enable_caching,
|
||||
enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size,
|
||||
pcp_world_size=pcp_world_size,
|
||||
)
|
||||
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec
|
||||
self.block_size = self.kv_cache_spec.block_size
|
||||
self.dcp_world_size = dcp_world_size
|
||||
self.pcp_world_size = pcp_world_size
|
||||
if dcp_world_size > 1:
|
||||
self.block_size *= dcp_world_size
|
||||
if pcp_world_size > 1:
|
||||
self.block_size *= pcp_world_size
|
||||
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
|
||||
"UnitaryKVCacheCoordinator assumes only one kv cache group"
|
||||
)
|
||||
@ -281,6 +290,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
|
||||
kv_cache_spec=self.kv_cache_spec,
|
||||
use_eagle=self.use_eagle,
|
||||
dcp_world_size=self.dcp_world_size,
|
||||
pcp_world_size=self.pcp_world_size,
|
||||
)
|
||||
return hit_blocks, len(hit_blocks[0]) * self.block_size
|
||||
|
||||
@ -302,6 +312,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
||||
enable_caching: bool,
|
||||
enable_kv_cache_events: bool,
|
||||
dcp_world_size: int,
|
||||
pcp_world_size: int,
|
||||
):
|
||||
super().__init__(
|
||||
kv_cache_config,
|
||||
@ -310,8 +321,10 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
||||
enable_caching,
|
||||
enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size,
|
||||
pcp_world_size=pcp_world_size,
|
||||
)
|
||||
assert dcp_world_size == 1, "DCP not support hybrid attn now."
|
||||
assert pcp_world_size == 1, "PCP not support hybrid attn now."
|
||||
self.verify_and_split_kv_cache_groups()
|
||||
|
||||
def verify_and_split_kv_cache_groups(self) -> None:
|
||||
@ -452,6 +465,7 @@ def get_kv_cache_coordinator(
|
||||
enable_caching: bool,
|
||||
enable_kv_cache_events: bool,
|
||||
dcp_world_size: int,
|
||||
pcp_world_size: int,
|
||||
) -> KVCacheCoordinator:
|
||||
if not enable_caching:
|
||||
return KVCacheCoordinatorNoPrefixCache(
|
||||
@ -460,6 +474,7 @@ def get_kv_cache_coordinator(
|
||||
use_eagle,
|
||||
enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size,
|
||||
pcp_world_size=pcp_world_size,
|
||||
)
|
||||
if len(kv_cache_config.kv_cache_groups) == 1:
|
||||
return UnitaryKVCacheCoordinator(
|
||||
@ -469,6 +484,7 @@ def get_kv_cache_coordinator(
|
||||
enable_caching,
|
||||
enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size,
|
||||
pcp_world_size=pcp_world_size,
|
||||
)
|
||||
return HybridKVCacheCoordinator(
|
||||
kv_cache_config,
|
||||
@ -477,4 +493,5 @@ def get_kv_cache_coordinator(
|
||||
enable_caching,
|
||||
enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size,
|
||||
pcp_world_size=pcp_world_size,
|
||||
)
|
||||
|
||||
@ -100,6 +100,7 @@ class KVCacheManager:
|
||||
log_stats: bool = False,
|
||||
enable_kv_cache_events: bool = False,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
) -> None:
|
||||
self.max_model_len = max_model_len
|
||||
|
||||
@ -124,12 +125,9 @@ class KVCacheManager:
|
||||
0
|
||||
].kv_cache_spec.block_size
|
||||
|
||||
if dcp_world_size > 1:
|
||||
if dcp_world_size * pcp_world_size > 1:
|
||||
assert len(kv_cache_config.kv_cache_groups) == 1
|
||||
# Note(hc): need revisit. When both DCP and any future
|
||||
# PCP are enabled, the block_size may need to be scaled
|
||||
# by a factor of dcp_size × pcp_size?
|
||||
self.block_size *= dcp_world_size
|
||||
self.block_size *= dcp_world_size * pcp_world_size
|
||||
|
||||
self.coordinator = get_kv_cache_coordinator(
|
||||
kv_cache_config=kv_cache_config,
|
||||
@ -138,6 +136,7 @@ class KVCacheManager:
|
||||
enable_caching=self.enable_caching,
|
||||
enable_kv_cache_events=enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size,
|
||||
pcp_world_size=pcp_world_size,
|
||||
)
|
||||
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
|
||||
self.block_pool = self.coordinator.block_pool
|
||||
|
||||
@ -1219,11 +1219,16 @@ def _report_kv_cache_config(
|
||||
// len(kv_cache_config.kv_cache_groups)
|
||||
* min_block_size
|
||||
)
|
||||
if vllm_config.parallel_config.decode_context_parallel_size > 1:
|
||||
num_tokens *= vllm_config.parallel_config.decode_context_parallel_size
|
||||
dcp_size = vllm_config.parallel_config.decode_context_parallel_size
|
||||
pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
|
||||
if pcp_size * dcp_size > 1:
|
||||
num_tokens *= pcp_size * dcp_size
|
||||
logger.info(
|
||||
"Multiplying the GPU KV cache size by the dcp_world_size %d.",
|
||||
vllm_config.parallel_config.decode_context_parallel_size,
|
||||
"Multiplying the GPU KV cache size by the cp_world_size %d "
|
||||
"(pcp_world_size %d * dcp_world_size %d).",
|
||||
pcp_size * dcp_size,
|
||||
pcp_size,
|
||||
dcp_size,
|
||||
)
|
||||
num_tokens_str = f"{num_tokens:,}"
|
||||
logger.info_once("GPU KV cache size: %s tokens", num_tokens_str, scope="local")
|
||||
|
||||
@ -121,6 +121,7 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
self.block_size = block_size
|
||||
self.dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
|
||||
self.pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size
|
||||
|
||||
# req_id -> Request
|
||||
self.requests: dict[str, Request] = {}
|
||||
@ -183,6 +184,7 @@ class Scheduler(SchedulerInterface):
|
||||
log_stats=self.log_stats,
|
||||
enable_kv_cache_events=self.enable_kv_cache_events,
|
||||
dcp_world_size=self.dcp_world_size,
|
||||
pcp_world_size=self.pcp_world_size,
|
||||
)
|
||||
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
|
||||
|
||||
|
||||
@ -32,6 +32,7 @@ class SingleTypeKVCacheManager(ABC):
|
||||
block_pool: BlockPool,
|
||||
kv_cache_group_id: int,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the SingleTypeKVCacheManager.
|
||||
@ -42,8 +43,9 @@ class SingleTypeKVCacheManager(ABC):
|
||||
"""
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
self.dcp_world_size = dcp_world_size
|
||||
if self.dcp_world_size > 1:
|
||||
self.block_size *= dcp_world_size
|
||||
self.pcp_world_size = pcp_world_size
|
||||
if dcp_world_size * pcp_world_size > 1:
|
||||
self.block_size *= dcp_world_size * pcp_world_size
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.block_pool = block_pool
|
||||
|
||||
@ -212,6 +214,7 @@ class SingleTypeKVCacheManager(ABC):
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
"""
|
||||
Get the longest cache hit prefix of the blocks that is not longer than
|
||||
@ -303,6 +306,7 @@ class FullAttentionManager(SingleTypeKVCacheManager):
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
assert isinstance(
|
||||
kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec)
|
||||
@ -314,8 +318,8 @@ class FullAttentionManager(SingleTypeKVCacheManager):
|
||||
[] for _ in range(len(kv_cache_group_ids))
|
||||
)
|
||||
block_size = kv_cache_spec.block_size
|
||||
if dcp_world_size > 1:
|
||||
block_size *= dcp_world_size
|
||||
if dcp_world_size * pcp_world_size > 1:
|
||||
block_size *= dcp_world_size * pcp_world_size
|
||||
max_num_blocks = max_length // block_size
|
||||
for block_hash in itertools.islice(block_hashes, max_num_blocks):
|
||||
# block_hashes is a chain of block hashes. If a block hash is not
|
||||
@ -362,11 +366,13 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
assert isinstance(kv_cache_spec, SlidingWindowSpec), (
|
||||
"SlidingWindowManager can only be used for sliding window groups"
|
||||
)
|
||||
assert dcp_world_size == 1, "DCP not support sliding window attn now."
|
||||
assert pcp_world_size == 1, "PCP not support sliding window attn now."
|
||||
|
||||
# The number of contiguous blocks needed for prefix cache hit.
|
||||
# -1 since the input token itself is also included in the window
|
||||
@ -476,6 +482,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
"""
|
||||
For chunked local attention, we need to find the longest cache hit
|
||||
@ -516,6 +523,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
|
||||
"Hybrid KV cache is not supported for " + "eagle + chunked local attention."
|
||||
)
|
||||
assert dcp_world_size == 1, "DCP not support chunked local attn now."
|
||||
assert pcp_world_size == 1, "PCP not support chunked local attn now."
|
||||
max_num_blocks = max_length // kv_cache_spec.block_size
|
||||
if max_length > 0:
|
||||
local_attention_start_idx = (
|
||||
@ -611,11 +619,13 @@ class MambaManager(SingleTypeKVCacheManager):
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
assert isinstance(kv_cache_spec, MambaSpec), (
|
||||
"MambaManager can only be used for mamba groups"
|
||||
)
|
||||
assert dcp_world_size == 1, "DCP not support mamba now."
|
||||
assert pcp_world_size == 1, "PCP not support mamba now."
|
||||
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
|
||||
[] for _ in range(len(kv_cache_group_ids))
|
||||
)
|
||||
@ -705,6 +715,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
assert isinstance(kv_cache_spec, CrossAttentionSpec), (
|
||||
"CrossAttentionManager can only be used for cross-attention groups"
|
||||
|
||||
@ -128,6 +128,7 @@ class EngineCore:
|
||||
scheduler_block_size = (
|
||||
vllm_config.cache_config.block_size
|
||||
* vllm_config.parallel_config.decode_context_parallel_size
|
||||
* vllm_config.parallel_config.prefill_context_parallel_size
|
||||
)
|
||||
|
||||
self.scheduler: SchedulerInterface = Scheduler(
|
||||
|
||||
@ -35,6 +35,7 @@ from vllm.distributed.parallel_state import (
|
||||
get_dp_group,
|
||||
get_ep_group,
|
||||
get_inner_dp_world_group,
|
||||
get_pcp_group,
|
||||
get_pp_group,
|
||||
get_tp_group,
|
||||
)
|
||||
@ -110,12 +111,14 @@ class MultiprocExecutor(Executor):
|
||||
f"({self.parallel_config.nnodes_within_dp}). "
|
||||
)
|
||||
self.local_world_size = self.parallel_config.local_world_size
|
||||
tensor_parallel_size = self.parallel_config.tensor_parallel_size
|
||||
pp_parallel_size = self.parallel_config.pipeline_parallel_size
|
||||
assert self.world_size == tensor_parallel_size * pp_parallel_size, (
|
||||
tp_size = self.parallel_config.tensor_parallel_size
|
||||
pp_size = self.parallel_config.pipeline_parallel_size
|
||||
pcp_size = self.parallel_config.prefill_context_parallel_size
|
||||
assert self.world_size == tp_size * pp_size * pcp_size, (
|
||||
f"world_size ({self.world_size}) must be equal to the "
|
||||
f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
|
||||
f"_parallel_size ({pp_parallel_size}). "
|
||||
f"tensor_parallel_size ({tp_size}) x pipeline"
|
||||
f"_parallel_size ({pp_size}) x prefill_context"
|
||||
f"_parallel_size ({pcp_size}). "
|
||||
)
|
||||
|
||||
# Set multiprocessing envs
|
||||
@ -424,7 +427,11 @@ class MultiprocExecutor(Executor):
|
||||
# 16-23, PP rank 2
|
||||
# 24-31, PP rank 3
|
||||
# so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3)
|
||||
return self.world_size - self.parallel_config.tensor_parallel_size
|
||||
return (
|
||||
self.world_size
|
||||
- self.parallel_config.tensor_parallel_size
|
||||
* self.parallel_config.prefill_context_parallel_size
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -828,6 +835,8 @@ class WorkerProc:
|
||||
dp_rank = get_dp_group().rank_in_group
|
||||
pp_size = get_pp_group().world_size
|
||||
pp_rank = get_pp_group().rank_in_group
|
||||
pcp_size = get_pcp_group().world_size
|
||||
pcp_rank = get_pcp_group().rank_in_group
|
||||
tp_size = get_tp_group().world_size
|
||||
tp_rank = get_tp_group().rank_in_group
|
||||
dcp_size = get_dcp_group().world_size
|
||||
@ -837,6 +846,8 @@ class WorkerProc:
|
||||
process_name += f"_DP{dp_rank}"
|
||||
if pp_size > 1:
|
||||
process_name += f"_PP{pp_rank}"
|
||||
if pcp_size > 1:
|
||||
process_name += f"_PCP{pcp_rank}"
|
||||
if tp_size > 1:
|
||||
process_name += f"_TP{tp_rank}"
|
||||
if dcp_size > 1:
|
||||
|
||||
@ -95,10 +95,11 @@ class FullAttentionSpec(AttentionSpec):
|
||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
|
||||
pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size
|
||||
# Note(hc): each dcp rank only need save
|
||||
# (max_model_len//dcp_world_size) tokens locally.
|
||||
if dcp_world_size > 1:
|
||||
max_model_len = cdiv(max_model_len, dcp_world_size)
|
||||
if dcp_world_size * pcp_world_size > 1:
|
||||
max_model_len = cdiv(max_model_len, dcp_world_size * pcp_world_size)
|
||||
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.distributed import get_dcp_group
|
||||
from vllm.distributed import get_dcp_group, get_pcp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
@ -22,7 +22,7 @@ class BlockTable:
|
||||
pin_memory: bool,
|
||||
device: torch.device,
|
||||
kernel_block_size: int,
|
||||
dcp_kv_cache_interleave_size: int,
|
||||
cp_kv_cache_interleave_size: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@ -80,6 +80,13 @@ class BlockTable:
|
||||
else:
|
||||
self._kernel_block_arange = None
|
||||
|
||||
try:
|
||||
self.pcp_world_size = get_pcp_group().world_size
|
||||
self.pcp_rank = get_pcp_group().rank_in_group
|
||||
except AssertionError:
|
||||
# DCP might not be initialized in testing
|
||||
self.pcp_world_size = 1
|
||||
self.pcp_rank = 0
|
||||
try:
|
||||
self.dcp_world_size = get_dcp_group().world_size
|
||||
self.dcp_rank = get_dcp_group().rank_in_group
|
||||
@ -87,7 +94,7 @@ class BlockTable:
|
||||
# DCP might not be initialized in testing
|
||||
self.dcp_world_size = 1
|
||||
self.dcp_rank = 0
|
||||
self.dcp_kv_cache_interleave_size = dcp_kv_cache_interleave_size
|
||||
self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
|
||||
|
||||
def append_row(
|
||||
self,
|
||||
@ -131,14 +138,16 @@ class BlockTable:
|
||||
# NOTE(woosuk): We can't simply use `token_indices // block_size`
|
||||
# here because M (max_model_len) is not necessarily divisible by
|
||||
# block_size.
|
||||
if self.dcp_world_size > 1:
|
||||
total_cp_world_size = self.pcp_world_size * self.dcp_world_size
|
||||
total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
|
||||
if total_cp_world_size > 1:
|
||||
# Note(hc): The DCP implement store kvcache with an interleave
|
||||
# style, the kvcache for the token whose token_idx is i is
|
||||
# always stored on the GPU whose dcp_rank equals i % cp_world_size:
|
||||
|
||||
# Use a "virtual block" which equals to world_size * block_size
|
||||
# for block_table_indices calculation.
|
||||
virtual_block_size = self.block_size * self.dcp_world_size
|
||||
virtual_block_size = self.block_size * total_cp_world_size
|
||||
block_table_indices = (
|
||||
req_indices * self.max_num_blocks_per_req
|
||||
+ positions // virtual_block_size
|
||||
@ -150,16 +159,16 @@ class BlockTable:
|
||||
virtual_block_offsets = positions % virtual_block_size
|
||||
mask = (
|
||||
virtual_block_offsets
|
||||
// self.dcp_kv_cache_interleave_size
|
||||
% self.dcp_world_size
|
||||
== self.dcp_rank
|
||||
// self.cp_kv_cache_interleave_size
|
||||
% total_cp_world_size
|
||||
== total_cp_rank
|
||||
)
|
||||
# Calculate local block_offsets
|
||||
block_offsets = (
|
||||
virtual_block_offsets
|
||||
// (self.dcp_world_size * self.dcp_kv_cache_interleave_size)
|
||||
* self.dcp_kv_cache_interleave_size
|
||||
+ virtual_block_offsets % self.dcp_kv_cache_interleave_size
|
||||
// (total_cp_world_size * self.cp_kv_cache_interleave_size)
|
||||
* self.cp_kv_cache_interleave_size
|
||||
+ virtual_block_offsets % self.cp_kv_cache_interleave_size
|
||||
)
|
||||
# Calculate slot_mapping
|
||||
slot_mapping = block_numbers * self.block_size + block_offsets
|
||||
@ -253,7 +262,7 @@ class MultiGroupBlockTable:
|
||||
block_sizes: list[int],
|
||||
kernel_block_sizes: list[int],
|
||||
num_speculative_tokens: int = 0,
|
||||
dcp_kv_cache_interleave_size: int = 1,
|
||||
cp_kv_cache_interleave_size: int = 1,
|
||||
) -> None:
|
||||
# Note(hc): each dcp rank only store
|
||||
# (max_model_len//dcp_world_size) tokens in kvcache,
|
||||
@ -283,7 +292,7 @@ class MultiGroupBlockTable:
|
||||
pin_memory,
|
||||
device,
|
||||
kernel_block_size,
|
||||
dcp_kv_cache_interleave_size,
|
||||
cp_kv_cache_interleave_size,
|
||||
)
|
||||
for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
|
||||
]
|
||||
|
||||
@ -87,7 +87,7 @@ class InputBatch:
|
||||
is_spec_decode: bool = False,
|
||||
is_pooling_model: bool = False,
|
||||
num_speculative_tokens: int = 0,
|
||||
dcp_kv_cache_interleave_size: int = 1,
|
||||
cp_kv_cache_interleave_size: int = 1,
|
||||
):
|
||||
self.is_pooling_model = is_pooling_model
|
||||
self.is_spec_decode = is_spec_decode
|
||||
@ -141,7 +141,7 @@ class InputBatch:
|
||||
block_sizes=block_sizes,
|
||||
kernel_block_sizes=kernel_block_sizes,
|
||||
num_speculative_tokens=num_speculative_tokens,
|
||||
dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
|
||||
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
|
||||
)
|
||||
|
||||
# Sampling-related.
|
||||
|
||||
@ -426,7 +426,7 @@ class GPUModelRunner(
|
||||
# uses output token ids so we set this conservatively.
|
||||
logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
|
||||
is_pooling_model=self.is_pooling_model,
|
||||
dcp_kv_cache_interleave_size=self.parallel_config.dcp_kv_cache_interleave_size,
|
||||
cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size,
|
||||
)
|
||||
|
||||
self.use_async_scheduling = self.scheduler_config.async_scheduling
|
||||
@ -1436,7 +1436,7 @@ class GPUModelRunner(
|
||||
self.seq_lens.cpu[:num_reqs],
|
||||
self.dcp_world_size,
|
||||
self.dcp_rank,
|
||||
self.parallel_config.dcp_kv_cache_interleave_size,
|
||||
self.parallel_config.cp_kv_cache_interleave_size,
|
||||
)
|
||||
self.dcp_local_seq_lens.copy_to_gpu(num_reqs)
|
||||
|
||||
|
||||
@ -26,6 +26,7 @@ from vllm.distributed.kv_transfer import (
|
||||
has_kv_transfer_group,
|
||||
)
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_pcp_group,
|
||||
get_pp_group,
|
||||
get_tp_group,
|
||||
)
|
||||
@ -733,6 +734,7 @@ class Worker(WorkerBase):
|
||||
module.global_num_experts = module.moe_config.num_experts
|
||||
module.moe_parallel_config = FusedMoEParallelConfig.make(
|
||||
tp_size_=get_tp_group().world_size,
|
||||
pcp_size_=get_pcp_group().world_size,
|
||||
dp_size_=get_dp_group().world_size,
|
||||
vllm_parallel_config=parallel_config,
|
||||
)
|
||||
@ -886,6 +888,7 @@ def init_worker_distributed_environment(
|
||||
ensure_model_parallel_initialized(
|
||||
parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size,
|
||||
parallel_config.prefill_context_parallel_size,
|
||||
parallel_config.decode_context_parallel_size,
|
||||
)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user