[Feature] Support Decode Context Parallel (DCP) for MLA (#23734)

Signed-off-by: hongchao <hongchao@msh.team>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: hongchao <hongchao@msh.team>
Co-authored-by: youkaichao <youkaichao@gmail.com>
yzds 2025-09-06 13:24:05 +08:00 committed by GitHub
parent 3c529fc994
commit ac201a0eaf
27 changed files with 999 additions and 230 deletions

View File

@ -837,7 +837,7 @@ steps:
- pytest -v -s models/test_oot_registration.py # it needs a clean process
- pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins
- label: Pipeline Parallelism Test # 45min
- label: Pipeline + Context Parallelism Test # 45min
timeout_in_minutes: 60
mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/tests"
@ -851,6 +851,7 @@
commands:
- pytest -v -s distributed/test_pp_cudagraph.py
- pytest -v -s distributed/test_pipeline_parallel.py
# - pytest -v -s distributed/test_context_parallel.py # TODO: enable it on Hopper runners or add triton MLA support
- label: LoRA TP Test (Distributed) # 17 min
timeout_in_minutes: 30 timeout_in_minutes: 30

View File

@ -36,13 +36,6 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
const std::string& kv_cache_dtype,
torch::Tensor& scale);
void cp_fused_concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
torch::Tensor& cp_local_token_select_indices,
torch::Tensor& kv_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
torch::Tensor& scale);
// Just for unittest
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double scale, const std::string& kv_cache_dtype);

View File

@ -396,51 +396,6 @@ __global__ void concat_and_cache_mla_kernel(
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
}
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void cp_fused_concat_and_cache_mla_kernel(
const scalar_t* __restrict__ kv_c, // [num_full_tokens, kv_lora_rank]
const scalar_t* __restrict__ k_pe, // [num_full_tokens, pe_dim]
const int64_t* __restrict__ cp_local_token_select_indices, // [num_tokens]
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
// + pe_dim)]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride, //
const int entry_stride, //
const int kv_c_stride, //
const int k_pe_stride, //
const int kv_lora_rank, //
const int pe_dim, //
const int block_size, //
const float* scale //
) {
const int64_t token_idx = cp_local_token_select_indices[blockIdx.x];
const int64_t slot_idx = slot_mapping[blockIdx.x];
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0) {
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
auto copy = [&](const scalar_t* __restrict__ src, cache_t* __restrict__ dst,
int src_stride, int dst_stride, int size, int offset) {
for (int i = threadIdx.x; i < size; i += blockDim.x) {
const int64_t src_idx = token_idx * src_stride + i;
const int64_t dst_idx =
block_idx * block_stride + block_offset * entry_stride + i + offset;
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
dst[dst_idx] = src[src_idx];
} else {
dst[dst_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(src[src_idx], *scale);
}
}
};
copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
}
}  // namespace vllm
// KV_T is the data type of key and value tensors.
@ -554,20 +509,6 @@ void reshape_and_cache_flash(
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
reinterpret_cast<const float*>(scale.data_ptr()));
// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_CP_FUSED_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
vllm::cp_fused_concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
cp_local_token_select_indices.data_ptr<int64_t>(), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride, \
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
reinterpret_cast<const float*>(scale.data_ptr()));
void concat_and_cache_mla(
torch::Tensor& kv_c,  // [num_tokens, kv_lora_rank]
torch::Tensor& k_pe,  // [num_tokens, pe_dim]
@ -606,50 +547,6 @@ void concat_and_cache_mla(
CALL_CONCAT_AND_CACHE_MLA);
}
// Note(hc): cp_fused_concat_and_cache_mla fuses the following three kernel
// calls into one:
// k_c_normed.index_select(0, cp_local_token_select_indices) + \
// k_pe.squeeze(1).index_select(0, cp_local_token_select_indices) + \
// concat_and_cache_mla.
void cp_fused_concat_and_cache_mla(
torch::Tensor& kv_c, // [num_total_tokens, kv_lora_rank]
torch::Tensor& k_pe, // [num_total_tokens, pe_dim]
torch::Tensor& cp_local_token_select_indices, // [num_tokens]
torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank +
// pe_dim)]
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
const std::string& kv_cache_dtype, torch::Tensor& scale) {
// NOTE(woosuk): In vLLM V1, key.size(0) can be different from
// slot_mapping.size(0) because of padding for CUDA graphs.
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
// both include padding.
// In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0)
// since key includes padding for CUDA graphs, while slot_mapping does not.
// In this case, slot_mapping.size(0) represents the actual number of tokens
// before padding.
// For compatibility with both cases, we use slot_mapping.size(0) as the
// number of tokens.
int num_tokens = slot_mapping.size(0);
int kv_lora_rank = kv_c.size(1);
int pe_dim = k_pe.size(1);
int block_size = kv_cache.size(1);
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
int kv_c_stride = kv_c.stride(0);
int k_pe_stride = k_pe.stride(0);
int block_stride = kv_cache.stride(0);
int entry_stride = kv_cache.stride(1);
dim3 grid(num_tokens);
dim3 block(std::min(kv_lora_rank, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
CALL_CP_FUSED_CONCAT_AND_CACHE_MLA);
}
namespace vllm {
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>

View File

@ -693,16 +693,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" Tensor scale) -> ()"); " Tensor scale) -> ()");
cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla); cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla);
cache_ops.def(
"cp_fused_concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
" Tensor cp_local_token_select_indices,"
" Tensor! kv_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor scale) -> ()");
cache_ops.impl("cp_fused_concat_and_cache_mla", torch::kCUDA,
&cp_fused_concat_and_cache_mla);
// Convert the key and value cache to fp8 data type.
cache_ops.def(
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "

View File

@ -0,0 +1,263 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
WARNING: This test runs in both single-node (4 GPUs) and multi-node
(2 nodes with 2 GPUs each) modes. If the test only uses 2 GPUs, it is
important to set the distributed backend to "mp" to avoid Ray scheduling
all workers on a node other than the head node, which can cause the test
to fail.
"""
import json
import os
from dataclasses import dataclass
from typing import Literal, NamedTuple, Optional
import pytest
from vllm.config import RunnerOption
from vllm.logger import init_logger
from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import compare_two_settings, create_new_process_for_each_test
logger = init_logger("test_context_parallel")
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
class ParallelSetup(NamedTuple):
tp_size: int
pp_size: int
dcp_size: int
eager_mode: bool
chunked_prefill: bool
class CPTestOptions(NamedTuple):
multi_node_only: bool
load_format: Optional[str] = None
@dataclass
class CPTestSettings:
parallel_setups: list[ParallelSetup]
# NOTE: the length of distributed_backends and
# vllm_major_versions should be the same, and they
# are first zipped together to iterate over all
# test settings.
distributed_backends: list[str]
# vllm major version: "0" for V0, "1" for V1
vllm_major_versions: list[str]
runner: RunnerOption
test_options: CPTestOptions
def __post_init__(self):
if len(self.distributed_backends) != len(self.vllm_major_versions):
raise ValueError(
f"Length mismatch: distributed_backends "
f"({len(self.distributed_backends)}) != "
f"vllm_major_versions ({len(self.vllm_major_versions)})")
@staticmethod
def detailed(
*,
tp_base: int = 4,
pp_base: int = 1,
dcp_base: int = 1,
multi_node_only: bool = False,
runner: RunnerOption = "auto",
load_format: Optional[str] = None,
):
parallel_setups = []
for eager_mode_val in [False]:
for pp_multiplier in [1]:
for dcp_multiplier in [2, 4]:
for chunked_prefill_val in [True]:
parallel_setups.append(
ParallelSetup(tp_size=tp_base,
pp_size=pp_multiplier * pp_base,
dcp_size=dcp_multiplier * dcp_base,
eager_mode=eager_mode_val,
chunked_prefill=chunked_prefill_val))
return CPTestSettings(
parallel_setups=parallel_setups,
distributed_backends=["mp"],
vllm_major_versions=["1"],
runner=runner,
test_options=CPTestOptions(multi_node_only=multi_node_only,
load_format=load_format),
)
def iter_params(self, model_id: str):
opts = self.test_options
for parallel_setup in self.parallel_setups:
for backend, vllm_major_version in zip(self.distributed_backends,
self.vllm_major_versions):
yield (model_id, parallel_setup, backend, vllm_major_version,
self.runner, opts)
def _compare_cp_with_tp(
model_id: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
runner: RunnerOption,
test_options: CPTestOptions,
num_gpus_available: int,
*,
method: Literal["generate"],
is_multimodal: bool,
):
(
tp_size,
pp_size,
dcp_size,
eager_mode,
chunked_prefill,
) = parallel_setup
multi_node_only, load_format = test_options
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
model_info.check_transformers_version(on_fail="skip")
trust_remote_code = model_info.trust_remote_code
tokenizer_mode = model_info.tokenizer_mode
hf_overrides = model_info.hf_overrides
if load_format == "dummy":
# Avoid OOM
text_overrides = {
"num_hidden_layers": 4,
"hidden_size": 512,
"intermediate_size": 800,
"num_attention_heads": 4,
"num_key_value_heads": 1,
}
if is_multimodal:
hf_overrides.update({"text_config": text_overrides})
else:
hf_overrides.update(text_overrides)
else:
model_info.check_available_online(on_fail="skip")
if num_gpus_available < tp_size * pp_size:
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
if VLLM_MULTI_NODE and distributed_backend == "mp":
pytest.skip("Skipping multi-node pipeline parallel test for "
"multiprocessing distributed backend")
if multi_node_only and not VLLM_MULTI_NODE:
pytest.skip("Not in multi-node setting")
common_args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--max-model-len",
"2048",
"--max-num-seqs",
"8",
]
if chunked_prefill:
common_args.append("--enable-chunked-prefill")
if eager_mode:
common_args.append("--enforce-eager")
if runner != "auto":
common_args.extend(["--runner", runner])
if trust_remote_code:
common_args.append("--trust-remote-code")
if tokenizer_mode:
common_args.extend(["--tokenizer-mode", tokenizer_mode])
if load_format:
common_args.extend(["--load-format", load_format])
if hf_overrides:
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
cp_env = tp_env = {
"VLLM_USE_V1":
vllm_major_version,  # Note(hc): DCP only supports the V1 engine
}
cp_args = [
*common_args,
"--tensor-parallel-size",
str(tp_size),
"--pipeline-parallel-size",
str(pp_size),
"--decode-context-parallel-size",
str(dcp_size),
"--distributed-executor-backend",
distributed_backend,
]
tp_args = [
*common_args,
"--tensor-parallel-size",
str(tp_size),
"--pipeline-parallel-size",
str(pp_size),
"--distributed-executor-backend",
distributed_backend,
]
try:
compare_two_settings(model_id,
cp_args,
tp_args,
cp_env,
tp_env,
method=method,
max_wait_seconds=720)
except Exception:
testing_ray_compiled_graph = cp_env is not None
if testing_ray_compiled_graph and vllm_major_version == "0":
# Ray Compiled Graph tests are flaky for V0,
# so we don't want to fail the test
logger.exception("Ray Compiled Graph tests failed")
else:
raise
CP_TEXT_GENERATION_MODELS = {
# [MLA attention only]
"deepseek-ai/DeepSeek-V2-Lite-Chat": CPTestSettings.detailed(),
}
CP_TEST_MODELS = [
# TODO support other models
# [LANGUAGE GENERATION]
"deepseek-ai/DeepSeek-V2-Lite-Chat",
]
@pytest.mark.parametrize(
("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
"runner", "test_options"),
[
params for model_id, settings in CP_TEXT_GENERATION_MODELS.items()
for params in settings.iter_params(model_id)
if model_id in CP_TEST_MODELS
],
)
@create_new_process_for_each_test()
def test_cp_generation(
model_id: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
runner: RunnerOption,
test_options: CPTestOptions,
num_gpus_available,
):
_compare_cp_with_tp(model_id,
parallel_setup,
distributed_backend,
vllm_major_version,
runner,
test_options,
num_gpus_available,
method="generate",
is_multimodal=False)

View File

@ -1625,20 +1625,6 @@ def concat_and_cache_mla(
scale)
def cp_fused_concat_and_cache_mla(
kv_c: torch.Tensor,
k_pe: torch.Tensor,
cp_local_token_select_indices: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
scale: torch.Tensor,
) -> None:
torch.ops._C_cache_ops.cp_fused_concat_and_cache_mla(
kv_c, k_pe, cp_local_token_select_indices, kv_cache, slot_mapping,
kv_cache_dtype, scale)
def copy_blocks(key_caches: list[torch.Tensor],
value_caches: list[torch.Tensor],
block_mapping: torch.Tensor) -> None:

View File

@ -0,0 +1,139 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.triton_utils import tl, triton
@triton.jit
def _correct_attn_cp_out_kernel(outputs_ptr, new_output_ptr, lses_ptr,
vlse_ptr, outputs_stride_B, outputs_stride_H,
outputs_stride_D, lses_stride_N, lses_stride_B,
lses_stride_H, lse_idx, HEAD_DIM: tl.constexpr,
N_ROUNDED: tl.constexpr):
"""
Apply the all-gathered lses to correct each local rank's attention
output. We still need to perform a cross-rank reduction to obtain the
final attention output.
Args:
output: [ B, H, D ]
lses : [ N, B, H ]
(N: cp size, B: batch, H: q_heads, D: v_head_dim)
Return:
output: [ B, H, D ]
lse : [ B, H ]
"""
batch_idx = tl.program_id(axis=0).to(tl.int64)
head_idx = tl.program_id(axis=1).to(tl.int64)
d_offsets = tl.arange(0, HEAD_DIM)
num_n_offsets = tl.arange(0, N_ROUNDED)
# shape = [N]
lse_offsets = num_n_offsets * lses_stride_N + batch_idx * \
lses_stride_B + head_idx * lses_stride_H
# calc final lse
lse = tl.load(lses_ptr + lse_offsets)
lse = tl.where((lse != lse) | (lse == float('inf')), -float('inf'), lse)
lse_max = tl.max(lse, axis=0)
lse -= lse_max
lse_exp = tl.exp(lse)
lse_acc = tl.sum(lse_exp, axis=0)
lse = tl.log(lse_acc)
lse += lse_max
lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H
tl.store(vlse_ptr + lse_offsets, lse)
# shape = [D]
output_offsets = batch_idx * outputs_stride_B + \
head_idx * outputs_stride_H + \
d_offsets * outputs_stride_D
# correct output
lse_offset = lse_idx * lses_stride_N + batch_idx * \
lses_stride_B + head_idx * lses_stride_H
lse_tmp = tl.load(lses_ptr + lse_offset)
lse_finally = lse_tmp - lse
lse_finally = tl.where(
(lse_finally != lse_finally) | (lse_finally == float('inf')),
-float('inf'), lse_finally)
factor = tl.exp(lse_finally)
output = tl.load(outputs_ptr + output_offsets)
output = output * factor
tl.store(new_output_ptr + output_offsets, output)
class CPTritonContext:
""" The CPTritonContext is used to avoid recompilation of the Triton JIT.
"""
def __init__(self):
self.inner_kernel = None
def call_kernel(self, kernel, grid, *regular_args, **const_args):
if self.inner_kernel is None:
self.inner_kernel = kernel[grid](*regular_args, **const_args)
else:
self.inner_kernel[grid](*regular_args)
def correct_attn_out(out: torch.Tensor, lses: torch.Tensor, cp_rank: int,
ctx: CPTritonContext):
"""
Apply the all-gathered lses to correct each local rank's attention
output. We still need to perform a cross-rank reduction to obtain the
final attention output.
Args:
output: [ B, H, D ]
lses : [ N, B, H ]
Return:
output: [ B, H, D ]
lse : [ B, H ]
"""
if ctx is None:
ctx = CPTritonContext()
lse = torch.empty_like(lses[0])
grid = (out.shape[0], out.shape[1], 1)
regular_args = (out, out, lses, lse, *out.stride(), *lses.stride(),
cp_rank)
const_args = {
"HEAD_DIM": out.shape[-1],
"N_ROUNDED": lses.shape[0],
}
ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args,
**const_args)
return out, lse
def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor,
cp_attn_lse: torch.Tensor,
cp_group: GroupCoordinator,
ctx: CPTritonContext = None):
"""
cp_attn_out: [ B, H, D ]
cp_attn_lse: [ B, H ]
"""
if cp_group.world_size == 1:
return cp_attn_out
if ctx is None:
ctx = CPTritonContext()
lses = torch.empty((cp_group.world_size, ) + cp_attn_lse.shape,
dtype=cp_attn_lse.dtype,
device=cp_attn_lse.device)
cp_attn_lse = cp_attn_lse.contiguous()
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
assert out.is_contiguous()
out = cp_group.reduce_scatter(out, dim=1)
return out
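For intuition, the correction above is the standard log-sum-exp merge of partial attention results computed over disjoint key shards. A minimal single-device PyTorch sketch of the same math (the function name and the stacked [N, B, H, D] / [N, B, H] layout are illustrative, not part of this change):
import torch
def merge_partial_attn(outs: torch.Tensor, lses: torch.Tensor):
    # outs: [N, B, H, D] partial attention outputs, one slice per DCP rank.
    # lses: [N, B, H] log-sum-exp of the attention scores on each rank.
    lse_total = torch.logsumexp(lses, dim=0)              # [B, H]
    weights = torch.exp(lses - lse_total).unsqueeze(-1)   # [N, B, H, 1]
    merged = (outs * weights).sum(dim=0)                  # [B, H, D]
    return merged, lse_total
In the distributed path each rank only rescales its local slice (the Triton kernel above), and the sum over ranks is realized by the reduce_scatter in cp_lse_ag_out_rs.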

View File

@ -105,7 +105,9 @@ def flash_mla_with_kvcache(
descale_q,
descale_k,
)
return out, softmax_lse
# Note(hc): needs revisiting when we support DCP with decode query_len > 1.
return out.squeeze(1), softmax_lse.squeeze(-1)
#

View File

@ -170,6 +170,11 @@ class ParallelConfig:
Set to be private as it's not intended to be configured by users.
"""
decode_context_parallel_size: int = 1
"""Number of decode context parallel groups. Because the world size does
not change with DCP, it simply reuses the GPUs of the TP group, and tp_size
needs to be divisible by dcp_size."""
@property
def world_size_across_dp(self) -> int:
"""world_size_across_dp is TPxPPxDP, it is the size of the world

View File

@ -904,6 +904,18 @@ def get_tensor_model_parallel_group():
return get_tp_group()
_DCP: Optional[GroupCoordinator] = None
def get_dcp_group() -> GroupCoordinator:
assert _DCP is not None, (
"decode context model parallel group is not initialized")
return _DCP
# kept for backward compatibility
get_context_model_parallel_group = get_dcp_group
_PP: Optional[GroupCoordinator] = None
_DP: Optional[GroupCoordinator] = None
@ -1034,6 +1046,7 @@ def init_distributed_environment(
def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
decode_context_model_parallel_size: Optional[int] = 1,
backend: Optional[str] = None,
) -> None:
"""
@ -1098,6 +1111,23 @@ def initialize_model_parallel(
use_message_queue_broadcaster=True,
group_name="tp")
# Build the DCP model-parallel groups.
global _DCP
assert _DCP is None, (
"decode context model parallel group is already initialized")
# Note(hc): In the current implementation of decode context parallel,
# dcp_size must not exceed tp_size, because the world size does not
# change with DCP; it simply reuses the GPUs of the TP group and splits
# one TP group into tp_size//dcp_size DCP groups.
group_ranks = all_ranks.reshape(
-1, decode_context_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_DCP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
use_message_queue_broadcaster=True,
group_name="dcp")
# Build the pipeline model-parallel groups.
global _PP
assert _PP is None, (
@ -1141,6 +1171,7 @@ def initialize_model_parallel(
def ensure_model_parallel_initialized(
tensor_model_parallel_size: int,
pipeline_model_parallel_size: int,
decode_context_model_parallel_size: Optional[int] = 1,
backend: Optional[str] = None,
) -> None:
"""Helper to initialize model parallel groups if they are not initialized,
@ -1151,7 +1182,8 @@ def ensure_model_parallel_initialized(
get_world_group().device_group)
if not model_parallel_is_initialized():
initialize_model_parallel(tensor_model_parallel_size,
pipeline_model_parallel_size, backend)
pipeline_model_parallel_size,
decode_context_model_parallel_size, backend)
return
assert (
@ -1226,6 +1258,16 @@ def get_tensor_model_parallel_rank():
return get_tp_group().rank_in_group
def get_decode_context_model_parallel_world_size():
"""Return world size for the decode context model parallel group."""
return get_dcp_group().world_size
def get_decode_context_model_parallel_rank():
"""Return my rank for the decode context model parallel group."""
return get_dcp_group().rank_in_group
def get_node_count() -> int:
"""Return the total number of nodes in the distributed environment. """
assert _NODE_COUNT is not None, (
@ -1246,6 +1288,11 @@ def destroy_model_parallel():
_PP.destroy()
_PP = None
global _DCP
if _DCP:
_DCP.destroy()
_DCP = None
global _DP
if _DP:
_DP.destroy()
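To make the DCP group construction concrete, here is a small illustrative sketch (hypothetical single-node layout with tp_size = 4 and decode_context_model_parallel_size = 2) that mirrors the reshape/unbind logic used in initialize_model_parallel:
import torch
# DCP does not add new ranks; it splits each TP group into
# tp_size // dcp_size groups that reuse the same GPUs.
all_ranks = torch.arange(4)   # the four TP ranks
dcp_size = 2
group_ranks = [x.tolist() for x in all_ranks.reshape(-1, dcp_size).unbind(0)]
print(group_ranks)            # [[0, 1], [2, 3]]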

View File

@ -306,6 +306,8 @@ class EngineArgs:
# number of P/D disaggregation (or other disaggregation) workers
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
decode_context_parallel_size: int = \
ParallelConfig.decode_context_parallel_size
data_parallel_size: int = ParallelConfig.data_parallel_size
data_parallel_rank: Optional[int] = None
data_parallel_start_rank: Optional[int] = None
@ -636,6 +638,9 @@ class EngineArgs:
**parallel_kwargs["pipeline_parallel_size"]) **parallel_kwargs["pipeline_parallel_size"])
parallel_group.add_argument("--tensor-parallel-size", "-tp", parallel_group.add_argument("--tensor-parallel-size", "-tp",
**parallel_kwargs["tensor_parallel_size"]) **parallel_kwargs["tensor_parallel_size"])
parallel_group.add_argument(
"--decode-context-parallel-size", "-dcp",
**parallel_kwargs["decode_context_parallel_size"])
parallel_group.add_argument("--data-parallel-size", "-dp", parallel_group.add_argument("--data-parallel-size", "-dp",
**parallel_kwargs["data_parallel_size"]) **parallel_kwargs["data_parallel_size"])
parallel_group.add_argument( parallel_group.add_argument(
@ -1156,6 +1161,17 @@ class EngineArgs:
# global layers in interleaved sliding window models. # global layers in interleaved sliding window models.
sliding_window = model_config.get_sliding_window() sliding_window = model_config.get_sliding_window()
# Note(hc): In the current implementation of decode context
# parallel (DCP), tp_size needs to be divisible by dcp_size,
# because the world size does not change with DCP; it simply
# reuses the GPUs of the TP group and splits one TP group into
# tp_size//dcp_size DCP groups.
assert self.tensor_parallel_size % self.decode_context_parallel_size \
== 0, (
f"tp_size={self.tensor_parallel_size} must be divisible by "
f"dcp_size={self.decode_context_parallel_size}."
)
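# Illustrative check (hypothetical values): tp_size=4 with dcp_size=2
# passes the assertion above (4 % 2 == 0) and splits each TP group into
# tp_size // dcp_size = 2 DCP groups of 2 GPUs; tp_size=4 with dcp_size=3
# would be rejected here.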
cache_config = CacheConfig(
block_size=self.block_size,
gpu_memory_utilization=self.gpu_memory_utilization,
@ -1306,6 +1322,7 @@ class EngineArgs:
distributed_executor_backend=self.distributed_executor_backend,
worker_cls=self.worker_cls,
worker_extension_cls=self.worker_extension_cls,
decode_context_parallel_size=self.decode_context_parallel_size,
)
speculative_config = self.create_speculative_config(

View File

@ -201,10 +201,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata,
MLAAttentionImpl)
from vllm.attention.backends.utils import get_mla_dims
from vllm.attention.ops.common import cp_lse_ag_out_rs
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import is_global_first_rank
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase,
@ -323,6 +324,13 @@ class MLACommonPrefillMetadata:
seq_lens: torch.Tensor
workspace: torch.Tensor
# for mla DCP
cp_chunk_seq_lens: Optional[list[list[int]]] = None
origin_context_lens: Optional[list[int]] = None
cp_cu_seq_lens: Optional[torch.Tensor] = None
chunk_size: Optional[int] = None
cu_seq_lens_lst: Optional[list[list[int]]] = None
block_table: torch.Tensor
query_start_loc: torch.Tensor
max_query_len: int
@ -444,6 +452,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
parallel_config)
self.mla_dims = get_mla_dims(self.model_config)
self.aot_schedule = current_platform.is_cuda()
try:
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
except AssertionError:
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
# Dont try to access the runner on AMD
if self.aot_schedule:
@ -465,6 +480,21 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
128 * 1024)
assert self.chunked_prefill_workspace_size >= \
scheduler_config.max_num_seqs * cache_config.block_size
if self.dcp_world_size > 1:
# Note(hc): The local kvcache is incomplete when DCP is enabled, so an
# additional kvcache allgather across the DCP group is required; the
# workspace therefore has to be enlarged by 1/DCP relative to the
# original TP allocation.
assert self.chunked_prefill_workspace_size % \
self.dcp_world_size == 0
self.chunked_prefill_workspace = torch.empty(
(self.chunked_prefill_workspace_size +
self.chunked_prefill_workspace_size // self.dcp_world_size,
self.model_config.get_head_size()),
dtype=self.model_config.dtype,
device=device,
)
else:
self.chunked_prefill_workspace = torch.empty(
(self.chunked_prefill_workspace_size,
self.model_config.get_head_size()),
@ -631,6 +661,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=self.reorder_batch_threshold)
# Note(hc): update seq_lens of decode reqs under DCP.
if self.dcp_world_size > 1:
seq_lens[:num_decodes] = seq_lens[:num_decodes] \
// self.dcp_world_size + (self.dcp_rank <= \
(seq_lens[:num_decodes] - 1) % self.dcp_world_size)
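# Illustrative example (hypothetical values): with dcp_world_size = 4 and
# a decode request whose 10 cached tokens are stored round-robin across
# the DCP ranks, the per-rank seq_len becomes
#   10 // 4 + (rank <= (10 - 1) % 4)  ->  3, 3, 2, 2 for ranks 0..3.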
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_tokens
@ -639,6 +675,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
reqs_start = num_decodes # prefill_start
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
# Note(hc): The context lengths from the perspective of dcp rank0.
cp_context_lens_cpu = torch.ceil(context_lens_cpu.float() /
self.dcp_world_size).int()
origin_context_lens = context_lens_cpu.tolist()
max_context_len_cpu = context_lens_cpu.max().item()
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
prefill_query_start_loc = query_start_loc[
@ -691,14 +731,60 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
out=cu_seq_lens_cpu[:, 1:],
dtype=torch.int32)
if self.dcp_world_size > 1:
# Note(hc): The above max_context_chunk already enforces
# block_size alignment; DCP only needs block_size to be
# divisible by dcp_world_size, because DCP uses
# cp_gather_cache, which does not require `cp_chunk_starts`
# to be aligned to page_size.
assert max_context_chunk % self.dcp_world_size == 0
cp_max_context_chunk = max_context_chunk // \
self.dcp_world_size
cp_chunk_starts = \
torch.arange(num_chunks, dtype=torch.int32) \
.unsqueeze(1).expand(-1, num_prefills) \
* cp_max_context_chunk
cp_chunk_ends = torch.min(
cp_context_lens_cpu.unsqueeze(0),
cp_chunk_starts + cp_max_context_chunk)
cp_chunk_seq_lens = (cp_chunk_ends -
cp_chunk_starts).clamp(min=0)
cp_cu_seq_lens_cpu = torch.zeros(num_chunks,
num_prefills + 1,
dtype=torch.int32,
pin_memory=True)
torch.cumsum(cp_chunk_seq_lens,
dim=1,
out=cp_cu_seq_lens_cpu[:, 1:],
dtype=torch.int32)
chunked_context_metadata_cls = \
CudnnPrefillMetadata.ChunkedContextMetadata \
if self._use_cudnn_prefill else \
MLACommonPrefillMetadata.ChunkedContextMetadata
if self.dcp_world_size > 1:
chunked_context_metadata = \
chunked_context_metadata_cls(
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
cu_seq_lens=cu_seq_lens_cpu \
.to(device, non_blocking=True),
starts=cp_chunk_starts.to(device, non_blocking=True),
seq_tot=cp_chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
seq_lens=chunk_seq_lens,
workspace=self.chunked_prefill_workspace,
cp_chunk_seq_lens=cp_chunk_seq_lens.tolist(),
origin_context_lens=origin_context_lens,
cp_cu_seq_lens=cp_cu_seq_lens_cpu \
.to(device, non_blocking=True),
chunk_size=max_context_chunk,
cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
)
else:
chunked_context_metadata = \
chunked_context_metadata_cls(
cu_seq_lens=cu_seq_lens_cpu \
.to(device, non_blocking=True),
starts=chunk_starts.to(device, non_blocking=True),
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
@ -757,6 +843,71 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
return attn_metadata
def reorg_kvcache(
allgatered_kv_c_normed: torch.Tensor,
allgatered_k_pe: torch.Tensor,
cp_chunk_seq_lens_lst: list[int],
origin_context_lens: list[int],
cp_world_size: int,
sum_seq_len: int,
max_seq_len: int,
chunk_size: int,
chunk_idx: int,
toks: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
reorg kvcache after cp local gather to tp layout for attn kernel.
Args:
cp_chunk_seq_lens_lst: chunk context lengths under CP.
origin_context_lens: origin full context lengths under CP.
cp_world_size: CP size.
sum_seq_len: the sum of cp_chunk_seq_lens_lst.
max_seq_len: the max value of cp_chunk_seq_lens_lst.
chunk_size: equals to max_context_chunk from
chunked_context_metadata building.
chunk_idx: chunk idx of chunked_prefill.
toks: the number of tokens for local gather cache.
"""
kv_c_segments = []
k_pe_segments = []
src_token_idx = 0
max_seq_len_check = 0
for cp_chunk_seq_len, origin_context_len in zip(cp_chunk_seq_lens_lst,
origin_context_lens):
chunk_context_len = chunk_size
if cp_chunk_seq_len != 0:
chunk_context_len = min(
chunk_context_len, origin_context_len - chunk_size * chunk_idx)
cp_target_rank = (chunk_context_len - 1) % cp_world_size
cur_seq_len = 0
for rank in range(cp_world_size):
if rank > cp_target_rank and cp_chunk_seq_len:
real_cp_chunk_seq_len = cp_chunk_seq_len - 1
else:
real_cp_chunk_seq_len = cp_chunk_seq_len
if real_cp_chunk_seq_len:
kv_c_segment = allgatered_kv_c_normed[rank * toks +
src_token_idx:rank *
toks + src_token_idx +
real_cp_chunk_seq_len]
k_pe_segment = allgatered_k_pe[rank * toks +
src_token_idx:rank * toks +
src_token_idx +
real_cp_chunk_seq_len]
kv_c_segments.append(kv_c_segment)
k_pe_segments.append(k_pe_segment)
cur_seq_len += real_cp_chunk_seq_len
max_seq_len_check = max(max_seq_len_check, cur_seq_len)
src_token_idx += cp_chunk_seq_len
reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
reorganized_k_pe = torch.cat(k_pe_segments, dim=0)
assert reorganized_kv_c_normed.shape[0] == sum_seq_len
assert reorganized_k_pe.shape[0] == sum_seq_len
assert max_seq_len_check == max_seq_len
return reorganized_kv_c_normed, reorganized_k_pe
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
"""
NOTE: Please read the comment at the top of the file before trying to
@ -836,6 +987,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self.vllm_flash_attn_version == 3
and current_platform.get_device_capability()[0] == 9)
self.dcp_world_size: Optional[int] = None
def _flash_attn_varlen_diff_headdims(self,
q,
k,
@ -1152,6 +1305,108 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
return output, output_lse
def _context_parallel_compute_prefill_context(
self,
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor,
dcp_world_size: int,
):
assert k_scale is None, "DCP does not support scaled kvcache yet."
assert attn_metadata.prefill is not None
prefill_metadata = attn_metadata.prefill
assert prefill_metadata.chunked_context is not None
assert prefill_metadata.chunked_context.cp_chunk_seq_lens is not None
assert prefill_metadata.chunked_context.origin_context_lens is not None
assert prefill_metadata.chunked_context.cp_cu_seq_lens is not None
assert prefill_metadata.chunked_context.chunk_size is not None
assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None
output = None
iters = len(prefill_metadata.chunked_context.seq_tot)
workspace = prefill_metadata.chunked_context.workspace
for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i]
ops.cp_gather_cache(
src_cache=kv_c_and_k_pe_cache,
dst=workspace,
block_table=prefill_metadata.block_table,
cu_seq_lens=prefill_metadata.chunked_context.cp_cu_seq_lens[i],
batch_size=attn_metadata.num_prefills,
seq_starts=prefill_metadata.chunked_context.starts[i],
)
# workspace
# |------- N tokens --------|--------- N*dcp_size tokens ----------|
# |<- use for local_gather ->|<--------- use for allgather -------->|
allgather_offset = workspace.shape[0] // (dcp_world_size + 1)
assert allgather_offset * (dcp_world_size +
1) == workspace.shape[0]
assert toks <= allgather_offset
local_gathered_kvcache = workspace[:toks]
cur_allgather_workspace = workspace[
allgather_offset:allgather_offset * (1 + dcp_world_size)]
assert toks * dcp_world_size <= cur_allgather_workspace.shape[0]
cur_allgather_kvcache = cur_allgather_workspace[:toks *
dcp_world_size]
cur_allgather_kvcache.copy_(get_dcp_group().all_gather(
local_gathered_kvcache, dim=0))
assert cur_allgather_kvcache.shape[
-1] == self.kv_lora_rank + self.qk_rope_head_dim
allgatered_kv_c_normed, allgatered_k_pe = \
cur_allgather_kvcache.unsqueeze(
1).split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed, k_pe = reorg_kvcache(
allgatered_kv_c_normed,
allgatered_k_pe,
cp_chunk_seq_lens_lst=prefill_metadata.chunked_context.
cp_chunk_seq_lens[i],
origin_context_lens=prefill_metadata.chunked_context.
origin_context_lens,
cp_world_size=dcp_world_size,
sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i]
[-1],
max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i],
chunk_size=prefill_metadata.chunked_context.chunk_size,
chunk_idx=i,
toks=toks)
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
dim=-1)
attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
prefill=prefill_metadata,
chunk_idx=i,
q=q,
k=k,
v=v,
)
if output is None:
output = attn_output
output_lse = attn_softmax_lse
else:
output_tmp = torch.empty_like(output)
output_lse_tmp = torch.empty_like(output_lse)
merge_attn_states(
output=output_tmp,
output_lse=output_lse_tmp,
prefix_output=output,
prefix_lse=output_lse,
suffix_output=attn_output,
suffix_lse=attn_softmax_lse,
)
output = output_tmp
output_lse = output_lse_tmp
return output, output_lse
def _forward_prefill(
self,
q: torch.Tensor,
@ -1162,6 +1417,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_scale: torch.Tensor,
) -> torch.Tensor:
assert attn_metadata.prefill is not None
assert self.dcp_world_size is not None
has_context = attn_metadata.prefill.chunked_context is not None
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
@ -1181,7 +1437,14 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if has_context:
suffix_output, suffix_lse = output
context_output, context_lse = self._compute_prefill_context( \
if self.dcp_world_size > 1:
context_output, context_lse = \
self._context_parallel_compute_prefill_context(
q, kv_c_and_k_pe_cache, attn_metadata,
k_scale=None, dcp_world_size=self.dcp_world_size)
else:
context_output, context_lse = \
self._compute_prefill_context(
q, kv_c_and_k_pe_cache, attn_metadata, k_scale)
output = torch.empty_like(suffix_output)
@ -1202,12 +1465,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
@abstractmethod
def _forward_decode(
self,
ql_nope: torch.Tensor,
q_pe: torch.Tensor,
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: M,
layer: AttentionLayer,
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
raise NotImplementedError
def forward(
@ -1235,6 +1497,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
# same expert outputs.
return output.fill_(0)
if self.dcp_world_size is None:
self.dcp_world_size = get_dcp_group().world_size
fp8_attention = self.kv_cache_dtype.startswith("fp8")
num_actual_toks = attn_metadata.num_actual_tokens
@ -1313,7 +1578,26 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
layer._q_scale)
decode_q_pe = decode_q_pe.reshape(q_pe_shape)
output[:num_decode_tokens] = self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer)
decode_q = (decode_ql_nope, decode_q_pe)
if self.dcp_world_size > 1:
assert not fp8_attention, "DCP does not support fp8 kvcache yet."
# concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P)
decode_q = torch.cat(decode_q, dim=-1)
# decode_q does allgather in the head dim.
decode_q = get_dcp_group().all_gather(decode_q, dim=1)
# call decode attn
attn_out, lse = self._forward_decode(decode_q, kv_cache,
attn_metadata, layer)
# correct dcp attn_out with lse.
if self.dcp_world_size > 1:
assert lse is not None, (
"For an MLA backend to enable "
"DCP, it is mandatory that the corresponding decode attn "
"kernel returns the softmax lse.")
attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
# v_up projection
output[:num_decode_tokens] = self._v_up_proj(attn_out)
return output_padded

View File

@ -232,7 +232,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
self._workspace.get_buf(),
self.scale, self._num_kv_splits)
return self._v_up_proj(o)
return o
# TODO: Currently we leave it here only for backup in case something is
# wrong with the new SM100 CUTLASS MLA kernel
@ -265,21 +265,25 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
attn_metadata.decode.seq_lens,
attn_metadata.decode.block_table, self.scale)
return self._v_up_proj(o)
return o
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
layer: AttentionLayer,
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if type(q) is tuple:
q_nope, q_pe = q
else:
q_nope, q_pe = torch.split(
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
if self._use_old_cutlass_mla:
# TODO: Remove the old cutlass MLA kernel after more extensive
# testing
return self._old_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
attn_metadata)
attn_metadata), None
return self._sm100_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
attn_metadata)
attn_metadata), None

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import ClassVar, Optional, Union
import torch
@ -154,15 +154,20 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashAttnMLAMetadata,
layer: AttentionLayer,
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if type(q) is tuple:
q_nope, q_pe = q
else:
q_nope, q_pe = torch.split(
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError(
"FP8 FlashAttention MLA not yet supported")

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import ClassVar, Optional, Union
import torch
@ -169,20 +169,20 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLAMetadata,
layer: AttentionLayer,
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1)  # Add seqlen dim of 1 (decode)
o, _ = flash_mla_with_kvcache(
q=q,
if type(q) is tuple:
q = torch.cat(q, dim=-1)
assert isinstance(q, torch.Tensor)
o, lse = flash_mla_with_kvcache(
q=q.unsqueeze(1),  # Add seqlen dim of 1 (decode)
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
@ -196,4 +196,4 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
descale_k=layer._k_scale.reshape(1),
)
return self._v_up_proj(o)
return o, lse

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import ClassVar, Optional, Union
import torch
@ -220,18 +220,19 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: AiterMLAMetadata,
layer: AttentionLayer,
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
B = q_nope.shape[0]
q = torch.cat([q_nope, q_pe], dim=-1)
if type(q) is tuple:
q = torch.cat(q, dim=-1)
assert isinstance(q, torch.Tensor)
B = q.shape[0]
o = torch.zeros(B,
self.num_heads,
self.kv_lora_rank,
@ -249,4 +250,4 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
attn_metadata.decode.paged_kv_indices,
attn_metadata.decode.paged_kv_last_page_len)
return self._v_up_proj(o)
return o, None

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Optional, Union
import torch
@ -123,21 +123,22 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
layer: AttentionLayer,
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 Triton MLA not yet supported")
B = q_nope.shape[0]
q = torch.cat([q_nope, q_pe], dim=-1)
if type(q) is tuple:
q = torch.cat(q, dim=-1)
assert isinstance(q, torch.Tensor)
B = q.shape[0]
o = torch.zeros(B,
self.num_heads,
self.kv_lora_rank,
@ -171,4 +172,4 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
attn_metadata.decode.seq_lens, attn_logits,
num_kv_splits, self.scale, PAGE_SIZE)
return self._v_up_proj(o)
return o, None

View File

@ -24,6 +24,7 @@ class KVCacheCoordinator(ABC):
use_eagle: bool,
enable_caching: bool,
enable_kv_cache_events: bool,
dcp_world_size: int,
):
self.kv_cache_config = kv_cache_config
self.max_model_len = max_model_len
@ -39,6 +40,7 @@ class KVCacheCoordinator(ABC):
kv_cache_spec=kv_cache_group.kv_cache_spec,
block_pool=self.block_pool,
kv_cache_group_id=i,
dcp_world_size=dcp_world_size,
) for i, kv_cache_group in enumerate(
self.kv_cache_config.kv_cache_groups))
@ -197,9 +199,14 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
""" """
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
use_eagle: bool, enable_kv_cache_events: bool): use_eagle: bool, enable_kv_cache_events: bool,
super().__init__(kv_cache_config, max_model_len, use_eagle, False, dcp_world_size: int):
enable_kv_cache_events) super().__init__(kv_cache_config,
max_model_len,
use_eagle,
False,
enable_kv_cache_events,
dcp_world_size=dcp_world_size)
self.num_single_type_manager = len(self.single_type_managers) self.num_single_type_manager = len(self.single_type_managers)
def get_num_common_prefix_blocks(self, request_id: str, def get_num_common_prefix_blocks(self, request_id: str,
@ -225,12 +232,19 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
use_eagle: bool, enable_caching: bool,
enable_kv_cache_events: bool):
super().__init__(kv_cache_config, max_model_len, use_eagle,
enable_caching, enable_kv_cache_events)
enable_kv_cache_events: bool, dcp_world_size: int):
super().__init__(kv_cache_config,
max_model_len,
use_eagle,
enable_caching,
enable_kv_cache_events,
dcp_world_size=dcp_world_size)
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[
0].kv_cache_spec
self.block_size = self.kv_cache_spec.block_size
self.dcp_world_size = dcp_world_size
if dcp_world_size > 1:
self.block_size *= dcp_world_size
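# Illustrative example (hypothetical values): with a per-rank block_size
# of 16 and dcp_world_size = 2, the coordinator reasons in logical blocks
# of 32 tokens, of which each DCP rank stores 16.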
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
"UnitaryKVCacheCoordinator assumes only one kv cache group")
@ -246,6 +260,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
block_pool=self.block_pool,
kv_cache_spec=self.kv_cache_spec,
use_eagle=self.use_eagle,
dcp_world_size=self.dcp_world_size,
)
return hit_blocks, len(hit_blocks[0]) * self.block_size
@ -261,9 +276,14 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
use_eagle: bool, enable_caching: bool,
enable_kv_cache_events: bool):
super().__init__(kv_cache_config, max_model_len, use_eagle,
enable_caching, enable_kv_cache_events)
enable_kv_cache_events: bool, dcp_world_size: int):
super().__init__(kv_cache_config,
max_model_len,
use_eagle,
enable_caching,
enable_kv_cache_events,
dcp_world_size=dcp_world_size)
assert dcp_world_size == 1, "DCP does not support hybrid attention yet."
self.verify_and_split_kv_cache_groups()
def verify_and_split_kv_cache_groups(self) -> None:
@ -394,17 +414,27 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
return hit_blocks, hit_length
def get_kv_cache_coordinator(
kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool,
enable_caching: bool,
enable_kv_cache_events: bool) -> KVCacheCoordinator:
if not enable_caching:
return KVCacheCoordinatorNoPrefixCache(kv_cache_config, max_model_len,
use_eagle,
enable_kv_cache_events)
if len(kv_cache_config.kv_cache_groups) == 1:
return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len,
use_eagle, enable_caching,
enable_kv_cache_events)
return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle,
enable_caching, enable_kv_cache_events)
def get_kv_cache_coordinator(kv_cache_config: KVCacheConfig,
max_model_len: int, use_eagle: bool,
enable_caching: bool,
enable_kv_cache_events: bool,
dcp_world_size: int) -> KVCacheCoordinator:
if not enable_caching:
return KVCacheCoordinatorNoPrefixCache(kv_cache_config,
max_model_len,
use_eagle,
enable_kv_cache_events,
dcp_world_size=dcp_world_size)
if len(kv_cache_config.kv_cache_groups) == 1:
return UnitaryKVCacheCoordinator(kv_cache_config,
max_model_len,
use_eagle,
enable_caching,
enable_kv_cache_events,
dcp_world_size=dcp_world_size)
return HybridKVCacheCoordinator(kv_cache_config,
max_model_len,
use_eagle,
enable_caching,
enable_kv_cache_events,
dcp_world_size=dcp_world_size)

View File

@ -91,6 +91,7 @@ class KVCacheManager:
         use_eagle: bool = False,
         log_stats: bool = False,
         enable_kv_cache_events: bool = False,
+        dcp_world_size: int = 1,
     ) -> None:
         self.max_model_len = max_model_len
@ -109,12 +110,20 @@ class KVCacheManager:
         self.block_size = kv_cache_config.kv_cache_groups[
             0].kv_cache_spec.block_size
+        if dcp_world_size > 1:
+            assert len(kv_cache_config.kv_cache_groups) == 1
+            # Note(hc): revisit this. If DCP is ever combined with a future
+            # prefill context parallel (PCP) scheme, the block_size may need
+            # to be scaled by dcp_size * pcp_size instead.
+            self.block_size *= dcp_world_size
         self.coordinator = get_kv_cache_coordinator(
             kv_cache_config=kv_cache_config,
             max_model_len=self.max_model_len,
             use_eagle=self.use_eagle,
             enable_caching=self.enable_caching,
             enable_kv_cache_events=enable_kv_cache_events,
+            dcp_world_size=dcp_world_size,
         )
         self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
         self.block_pool = self.coordinator.block_pool

View File

@ -846,6 +846,12 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
     )
     num_tokens = num_blocks * vllm_config.cache_config.block_size
+    if vllm_config.parallel_config.decode_context_parallel_size > 1:
+        num_tokens *= vllm_config.parallel_config.decode_context_parallel_size
+        logger.info(
+            "Multiplying the GPU KV cache size by the dcp_world_size %d.",
+            vllm_config.parallel_config.decode_context_parallel_size)
     num_tokens_str = f"{num_tokens:,}"
     logger.info("GPU KV cache size: %s tokens", num_tokens_str)
     max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
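
For intuition on the capacity logging above: each DCP rank holds an interleaved 1/dcp_world_size share of every sequence, so the group as a whole can cache dcp_world_size times as many logical tokens as one rank's block count suggests. A minimal arithmetic sketch with made-up numbers (none of these values come from the commit):

    # Illustrative values only.
    num_blocks_per_rank = 10_000   # KV blocks that fit in one GPU's memory
    block_size = 16                # tokens stored per block on a single rank
    dcp_world_size = 4             # decode context parallel degree

    per_rank_tokens = num_blocks_per_rank * block_size
    logical_tokens = per_rank_tokens * dcp_world_size  # what the log line reports

    print(f"per-rank KV capacity : {per_rank_tokens:,} tokens")
    print(f"logical KV capacity  : {logical_tokens:,} tokens")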

View File

@ -100,6 +100,15 @@ class Scheduler(SchedulerInterface):
         self.block_size = self.cache_config.block_size
+        self.dcp_world_size = \
+            vllm_config.parallel_config.decode_context_parallel_size
+        # Note(hc): The scheduler's block_size must be multiplied by
+        # dcp_world_size, since block hashes are computed on the original,
+        # full token sequence at a granularity of
+        # original_block_size * dcp_world_size.
+        if self.dcp_world_size > 1:
+            self.block_size *= self.dcp_world_size
         # req_id -> Request
         self.requests: dict[str, Request] = {}
         # Scheduling policy
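
To illustrate the note above: prefix-cache block hashes are computed over the original, un-split token sequence, so with DCP one hashable chunk must cover block_size * dcp_world_size tokens (one physical block on each rank). The following is a toy chunk-hashing loop under those assumptions, not the scheduler's actual hashing code:

    import hashlib

    def chunk_hashes(token_ids: list[int], chunk_size: int) -> list[str]:
        """Hash a token sequence in fixed-size chunks (full chunks only)."""
        hashes = []
        for start in range(0, len(token_ids) - chunk_size + 1, chunk_size):
            chunk = token_ids[start:start + chunk_size]
            hashes.append(hashlib.sha256(bytes(chunk)).hexdigest()[:8])
        return hashes

    tokens = list(range(64))   # toy prompt of 64 token ids
    block_size = 16            # per-rank physical block size
    dcp_world_size = 2

    # The scheduler hashes at the logical granularity of a whole DCP group.
    logical_block_size = block_size * dcp_world_size
    print(chunk_hashes(tokens, logical_block_size))  # 2 hashes, 32 tokens each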
@ -161,6 +170,7 @@ class Scheduler(SchedulerInterface):
             use_eagle=self.use_eagle,
             log_stats=self.log_stats,
             enable_kv_cache_events=self.enable_kv_cache_events,
+            dcp_world_size=self.dcp_world_size,
         )
         self.use_pp = self.parallel_config.pipeline_parallel_size > 1

View File

@ -25,6 +25,7 @@ class SingleTypeKVCacheManager(ABC):
         kv_cache_spec: KVCacheSpec,
         block_pool: BlockPool,
         kv_cache_group_id: int,
+        dcp_world_size: int = 1,
     ) -> None:
         """
         Initializes the SingleTypeKVCacheManager.
@ -33,8 +34,10 @@ class SingleTypeKVCacheManager(ABC):
             block_pool: The block pool.
             kv_cache_group_id: The id of the kv cache group of this manager.
         """
         self.block_size = kv_cache_spec.block_size
+        self.dcp_world_size = dcp_world_size
+        if self.dcp_world_size > 1:
+            self.block_size *= dcp_world_size
         self.kv_cache_spec = kv_cache_spec
         self.block_pool = block_pool
@ -196,6 +199,7 @@ class SingleTypeKVCacheManager(ABC):
         block_pool: BlockPool,
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
+        dcp_world_size: int = 1,
     ) -> tuple[list[KVCacheBlock], ...]:
         """
         Get the longest cache hit prefix of the blocks that is not longer than
@ -253,6 +257,7 @@ class FullAttentionManager(SingleTypeKVCacheManager):
         block_pool: BlockPool,
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
+        dcp_world_size: int = 1,
     ) -> tuple[list[KVCacheBlock], ...]:
         assert isinstance(
             kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec)
@ -260,7 +265,10 @@ class FullAttentionManager(SingleTypeKVCacheManager):
"and chunked local attention groups" "and chunked local attention groups"
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
[] for _ in range(len(kv_cache_group_ids))) [] for _ in range(len(kv_cache_group_ids)))
max_num_blocks = max_length // kv_cache_spec.block_size block_size = kv_cache_spec.block_size
if dcp_world_size > 1:
block_size *= dcp_world_size
max_num_blocks = max_length // block_size
for block_hash in itertools.islice(block_hashes, max_num_blocks): for block_hash in itertools.islice(block_hashes, max_num_blocks):
# block_hashes is a chain of block hashes. If a block hash is not # block_hashes is a chain of block hashes. If a block hash is not
# in the cached_block_hash_to_id, the following block hashes are # in the cached_block_hash_to_id, the following block hashes are
@ -310,9 +318,11 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
         block_pool: BlockPool,
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
+        dcp_world_size: int = 1,
     ) -> tuple[list[KVCacheBlock], ...]:
         assert isinstance(kv_cache_spec, SlidingWindowSpec), (
             "SlidingWindowManager can only be used for sliding window groups")
+        assert dcp_world_size == 1, (
+            "DCP does not support sliding window attention yet.")
         # The number of contiguous blocks needed for prefix cache hit.
         # -1 since the input token itself is also included in the window
@ -408,6 +418,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
         block_pool: BlockPool,
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
+        dcp_world_size: int = 1,
     ) -> tuple[list[KVCacheBlock], ...]:
         """
         For chunked local attention, we need to find the longest cache hit
@ -445,6 +456,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
"chunked local attention groups") "chunked local attention groups")
assert use_eagle is False, ("Hybrid KV cache is not supported for " + assert use_eagle is False, ("Hybrid KV cache is not supported for " +
"eagle + chunked local attention.") "eagle + chunked local attention.")
assert dcp_world_size == 1, "DCP not support chunked local attn now."
max_num_blocks = max_length // kv_cache_spec.block_size max_num_blocks = max_length // kv_cache_spec.block_size
if max_length > 0: if max_length > 0:
local_attention_start_idx = (max_length // local_attention_start_idx = (max_length //
@ -525,10 +537,12 @@ class MambaManager(SingleTypeKVCacheManager):
         block_pool: BlockPool,
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
+        dcp_world_size: int = 1,
     ) -> tuple[list[KVCacheBlock], ...]:
         assert isinstance(
             kv_cache_spec,
             MambaSpec), ("MambaManager can only be used for mamba groups")
+        assert dcp_world_size == 1, "DCP does not support Mamba yet."
         # Prefix caching is not supported for mamba now. Always return empty
         # list.
         computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
@ -583,6 +597,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
         block_pool: BlockPool,
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
+        dcp_world_size: int = 1,
     ) -> tuple[list[KVCacheBlock], ...]:
         assert isinstance(kv_cache_spec, CrossAttentionSpec), (
             "CrossAttentionManager can only be used for cross-attention groups"

View File

@ -86,6 +86,12 @@ class FullAttentionSpec(AttentionSpec):
     def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
         max_model_len = vllm_config.model_config.max_model_len
+        dcp_world_size = \
+            vllm_config.parallel_config.decode_context_parallel_size
+        # Note(hc): each DCP rank only needs to store
+        # (max_model_len // dcp_world_size) tokens locally.
+        if dcp_world_size > 1:
+            max_model_len = cdiv(max_model_len, dcp_world_size)
         return cdiv(max_model_len, self.block_size) * self.page_size_bytes

     @classmethod
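
The per-rank sizing above is plain ceiling-division arithmetic on a shorter local sequence. A small sketch with an assumed page size (the page_size_bytes value is illustrative, not FullAttentionSpec's real figure):

    def cdiv(a: int, b: int) -> int:
        """Ceiling division, as in vllm.utils.cdiv."""
        return -(a // -b)

    max_model_len = 131_072
    block_size = 16
    page_size_bytes = 18_432   # assumed bytes per KV block, for illustration

    for dcp_world_size in (1, 2, 4, 8):
        local_len = max_model_len
        if dcp_world_size > 1:
            local_len = cdiv(max_model_len, dcp_world_size)
        bytes_per_rank = cdiv(local_len, block_size) * page_size_bytes
        print(f"dcp={dcp_world_size}: {bytes_per_rank / 2**20:.1f} MiB per rank")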
@ -162,6 +168,8 @@ class SlidingWindowSpec(AttentionSpec):
         assert not self.use_mla, "MLA is not supported for sliding window"

     def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
+        assert vllm_config.parallel_config.decode_context_parallel_size == 1, \
+            "DCP does not support sliding window attention."
         max_model_len = vllm_config.model_config.max_model_len
         max_num_batched_tokens = (
             vllm_config.scheduler_config.max_num_batched_tokens)

View File

@ -4,6 +4,7 @@
 import numpy as np
 import torch

+from vllm.distributed import get_dcp_group
 from vllm.logger import init_logger
 from vllm.utils import cdiv
@ -50,6 +51,13 @@ class BlockTable:
         self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
                                         dtype=torch.int64,
                                         device=self.device)
+        try:
+            self.dcp_world_size = get_dcp_group().world_size
+            self.dcp_rank = get_dcp_group().rank_in_group
+        except AssertionError:
+            # DCP might not be initialized in testing
+            self.dcp_world_size = 1
+            self.dcp_rank = 0

     def append_row(
         self,
@ -89,6 +97,29 @@ class BlockTable:
         # NOTE(woosuk): We can't simply use `token_indices // block_size`
         # here because M (max_model_len) is not necessarily divisible by
         # block_size.
-        block_table_indices = (req_indices * self.max_num_blocks_per_req +
-                               positions // self.block_size)
-        block_numbers = self.block_table_np.ravel()[block_table_indices]
+        if self.dcp_world_size > 1:
+            # Note(hc): DCP stores the KV cache in an interleaved style: the
+            # KV cache for the token with token_idx i always lives on the GPU
+            # whose dcp_rank equals i % dcp_world_size. Use a "virtual block"
+            # of world_size * block_size tokens for the block_table_indices
+            # calculation.
+            virtual_block_size = self.block_size * self.dcp_world_size
+            block_table_indices = (req_indices * self.max_num_blocks_per_req +
+                                   positions // virtual_block_size)
+            block_numbers = self.block_table_np.ravel()[block_table_indices]
+            # Use virtual_block_size for the mask calculation, which marks
+            # local tokens.
+            virtual_block_offsets = positions % virtual_block_size
+            mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank
+            # Calculate local block_offsets
+            block_offsets = virtual_block_offsets // self.dcp_world_size
+            # Calculate slot_mapping
+            slot_mapping = block_numbers * self.block_size + block_offsets
+            # Write the final slots, using -1 for tokens that are not local
+            self.slot_mapping_np[:req_indices.shape[0]] = np.where(
+                mask, slot_mapping, -1)
+        else:
+            block_table_indices = (req_indices * self.max_num_blocks_per_req +
+                                   positions // self.block_size)
+            block_numbers = self.block_table_np.ravel()[block_table_indices]
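
To make the virtual-block arithmetic above concrete, here is a standalone NumPy reproduction for a single request. The block-table contents and sizes are invented; only the indexing math mirrors the hunk:

    import numpy as np

    block_size = 4                    # physical tokens per block on one rank
    dcp_world_size = 2
    virtual_block_size = block_size * dcp_world_size  # 8 logical tokens per row

    positions = np.arange(16)         # token positions 0..15 of one request
    block_table = np.array([7, 3])    # physical block ids for virtual blocks 0, 1

    for dcp_rank in range(dcp_world_size):
        block_numbers = block_table[positions // virtual_block_size]
        virtual_offsets = positions % virtual_block_size
        mask = virtual_offsets % dcp_world_size == dcp_rank  # tokens this rank owns
        local_offsets = virtual_offsets // dcp_world_size
        slot_mapping = np.where(mask, block_numbers * block_size + local_offsets, -1)
        print(f"rank {dcp_rank}: {slot_mapping}")

Running this shows token i landing on rank i % dcp_world_size, with every other token masked out as -1 on each rank.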
@ -128,9 +159,19 @@ class MultiGroupBlockTable:
     def __init__(self, max_num_reqs: int, max_model_len: int,
                  max_num_batched_tokens: int, pin_memory: bool,
                  device: torch.device, block_sizes: list[int]) -> None:
+        # Note(hc): each DCP rank only stores
+        # (max_model_len // dcp_world_size) tokens in its KV cache, so the
+        # block_size used to compute max_num_blocks_per_req must be
+        # multiplied by dcp_world_size.
+        try:
+            dcp_world_size = get_dcp_group().world_size
+        except AssertionError:
+            # DCP might not be initialized in testing
+            dcp_world_size = 1
         self.block_tables = [
-            BlockTable(block_size, max_num_reqs, cdiv(max_model_len,
-                                                      block_size),
+            BlockTable(block_size, max_num_reqs,
+                       cdiv(max_model_len, block_size * dcp_world_size),
                        max_num_batched_tokens, pin_memory, device)
             for block_size in block_sizes
         ]
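
The narrower per-request block table follows directly from the scaled block size: each rank only needs enough rows to cover its 1/dcp_world_size share of max_model_len. A quick check with illustrative values:

    def cdiv(a: int, b: int) -> int:
        return -(a // -b)

    max_model_len = 32_768
    block_size = 16

    for dcp_world_size in (1, 2, 4):
        max_num_blocks_per_req = cdiv(max_model_len, block_size * dcp_world_size)
        print(f"dcp={dcp_world_size}: {max_num_blocks_per_req} blocks per request")
    # dcp=1: 2048, dcp=2: 1024, dcp=4: 512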

View File

@ -56,6 +56,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, cdiv, check_use_alibi,
                         get_dtype_size, is_pin_memory_available, round_up,
                         supports_dynamo)
+from vllm.v1.attention.backends.mla.flashmla import FlashMLABackend
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
     create_fast_prefill_custom_backend,
@ -187,6 +188,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             model_config.is_multimodal_raw_input_only_model)

         self.max_model_len = model_config.max_model_len
+        self.dcp_world_size = self.parallel_config.decode_context_parallel_size
         self.max_num_tokens = scheduler_config.max_num_batched_tokens
         self.max_num_reqs = scheduler_config.max_num_seqs
@ -428,6 +430,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             return
         if self.reorder_batch_threshold is not None:
+            if self.dcp_world_size > 1:
+                assert self.reorder_batch_threshold == 1, \
+                    "DCP does not support reorder_batch_threshold > 1 yet."
             reorder_batch_to_split_decodes_and_prefills(
                 self.input_batch,
                 scheduler_output,
@ -3305,6 +3310,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             get_kv_transfer_group().set_host_xfer_buffer_ops(
                 copy_kv_blocks)

+        if self.dcp_world_size > 1:
+            assert self.attn_groups[0][0].backend is FlashMLABackend, (
+                "DCP only supports FlashMLA for now. For an MLA backend to "
+                "enable DCP, the corresponding decode attention kernel must "
+                "return the softmax LSE.")

     def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
         """
         Add encoder-only layers to the KV cache config.
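
The FlashMLA restriction above comes from how DCP combines partial attention results: each rank attends over only its local slice of the KV cache, and the partial outputs can be merged exactly only if every rank also returns the softmax log-sum-exp (LSE) of its local logits. A minimal PyTorch sketch of that merge (the function name and shapes are illustrative, not the kernel's actual interface):

    import torch

    def merge_partial_attention(outs: list[torch.Tensor],
                                lses: list[torch.Tensor]) -> torch.Tensor:
        """Merge per-rank attention outputs using their softmax LSEs.

        outs[i]: [num_heads, head_dim] partial output over rank i's KV slice
        lses[i]: [num_heads] log-sum-exp of rank i's attention logits
        """
        lse = torch.stack(lses)                   # [ranks, num_heads]
        out = torch.stack(outs)                   # [ranks, num_heads, head_dim]
        global_lse = torch.logsumexp(lse, dim=0)  # [num_heads]
        weights = torch.exp(lse - global_lse)     # per-rank correction factors
        return (weights.unsqueeze(-1) * out).sum(dim=0)

    # Toy check: one head, two ranks, two KV entries per rank.
    scores = torch.tensor([[0.1, 0.7], [0.3, 0.2]])        # logits, split by rank
    values = torch.tensor([[[1., 0.], [0., 1.]],
                           [[1., 1.], [2., 0.]]])          # [ranks, kv, head_dim]
    outs, lses = [], []
    for r in range(2):
        p = torch.softmax(scores[r], dim=-1)
        outs.append((p.unsqueeze(-1) * values[r]).sum(0, keepdim=True))  # [1, dim]
        lses.append(torch.logsumexp(scores[r], dim=-1, keepdim=True))    # [1]
    merged = merge_partial_attention(outs, lses)
    # Reference: one softmax over all four logits and all four values.
    full_p = torch.softmax(scores.reshape(-1), dim=-1)
    reference = (full_p.unsqueeze(-1) * values.reshape(-1, 2)).sum(0, keepdim=True)
    assert torch.allclose(merged, reference, atol=1e-6)

Without the per-rank LSE there is no way to renormalize the locally-softmaxed outputs, which is why a backend must expose it before DCP can be enabled for it.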

View File

@ -616,7 +616,9 @@ def init_worker_distributed_environment(
     init_distributed_environment(parallel_config.world_size, rank,
                                  distributed_init_method, local_rank, backend)
-    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
-                                      parallel_config.pipeline_parallel_size)
+    ensure_model_parallel_initialized(
+        parallel_config.tensor_parallel_size,
+        parallel_config.pipeline_parallel_size,
+        parallel_config.decode_context_parallel_size)
     ensure_kv_transfer_initialized(vllm_config)

View File

@ -539,8 +539,10 @@ def init_worker_distributed_environment(
     init_distributed_environment(parallel_config.world_size, rank,
                                  distributed_init_method, local_rank,
                                  current_platform.dist_backend)
-    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
-                                      parallel_config.pipeline_parallel_size)
+    ensure_model_parallel_initialized(
+        parallel_config.tensor_parallel_size,
+        parallel_config.pipeline_parallel_size,
+        parallel_config.decode_context_parallel_size)
     ensure_kv_transfer_initialized(vllm_config)