mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 21:45:44 +08:00
[Feature] Support Decode Context Parallel (DCP) for MLA (#23734)
Signed-off-by: hongchao <hongchao@msh.team> Signed-off-by: youkaichao <youkaichao@gmail.com> Co-authored-by: hongchao <hongchao@msh.team> Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
3c529fc994
commit
ac201a0eaf
@ -837,7 +837,7 @@ steps:
|
||||
- pytest -v -s models/test_oot_registration.py # it needs a clean process
|
||||
- pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins
|
||||
|
||||
- label: Pipeline Parallelism Test # 45min
|
||||
- label: Pipeline + Context Parallelism Test # 45min
|
||||
timeout_in_minutes: 60
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
@ -851,6 +851,7 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s distributed/test_pp_cudagraph.py
|
||||
- pytest -v -s distributed/test_pipeline_parallel.py
|
||||
# - pytest -v -s distributed/test_context_parallel.py # TODO: enable it on Hopper runners or add triton MLA support
|
||||
|
||||
- label: LoRA TP Test (Distributed) # 17 min
|
||||
timeout_in_minutes: 30
|
||||
|
||||
@ -36,13 +36,6 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
|
||||
const std::string& kv_cache_dtype,
|
||||
torch::Tensor& scale);
|
||||
|
||||
void cp_fused_concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
|
||||
torch::Tensor& cp_local_token_select_indices,
|
||||
torch::Tensor& kv_cache,
|
||||
torch::Tensor& slot_mapping,
|
||||
const std::string& kv_cache_dtype,
|
||||
torch::Tensor& scale);
|
||||
|
||||
// Just for unittest
|
||||
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
|
||||
const double scale, const std::string& kv_cache_dtype);
|
||||
|
||||
@ -396,51 +396,6 @@ __global__ void concat_and_cache_mla_kernel(
|
||||
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
|
||||
__global__ void cp_fused_concat_and_cache_mla_kernel(
|
||||
const scalar_t* __restrict__ kv_c, // [num_full_tokens, kv_lora_rank]
|
||||
const scalar_t* __restrict__ k_pe, // [num_full_tokens, pe_dim]
|
||||
const int64_t* __restrict__ cp_local_token_select_indices, // [num_tokens]
|
||||
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
|
||||
// + pe_dim)]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int block_stride, //
|
||||
const int entry_stride, //
|
||||
const int kv_c_stride, //
|
||||
const int k_pe_stride, //
|
||||
const int kv_lora_rank, //
|
||||
const int pe_dim, //
|
||||
const int block_size, //
|
||||
const float* scale //
|
||||
) {
|
||||
const int64_t token_idx = cp_local_token_select_indices[blockIdx.x];
|
||||
const int64_t slot_idx = slot_mapping[blockIdx.x];
|
||||
// NOTE: slot_idx can be -1 if the token is padded
|
||||
if (slot_idx < 0) {
|
||||
return;
|
||||
}
|
||||
const int64_t block_idx = slot_idx / block_size;
|
||||
const int64_t block_offset = slot_idx % block_size;
|
||||
|
||||
auto copy = [&](const scalar_t* __restrict__ src, cache_t* __restrict__ dst,
|
||||
int src_stride, int dst_stride, int size, int offset) {
|
||||
for (int i = threadIdx.x; i < size; i += blockDim.x) {
|
||||
const int64_t src_idx = token_idx * src_stride + i;
|
||||
const int64_t dst_idx =
|
||||
block_idx * block_stride + block_offset * entry_stride + i + offset;
|
||||
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
|
||||
dst[dst_idx] = src[src_idx];
|
||||
} else {
|
||||
dst[dst_idx] =
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(src[src_idx], *scale);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
|
||||
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
// KV_T is the data type of key and value tensors.
|
||||
@ -554,20 +509,6 @@ void reshape_and_cache_flash(
|
||||
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
|
||||
reinterpret_cast<const float*>(scale.data_ptr()));
|
||||
|
||||
// KV_T is the data type of key and value tensors.
|
||||
// CACHE_T is the stored data type of kv-cache.
|
||||
// KV_DTYPE is the real data type of kv-cache.
|
||||
#define CALL_CP_FUSED_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::cp_fused_concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
|
||||
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
|
||||
cp_local_token_select_indices.data_ptr<int64_t>(), \
|
||||
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride, \
|
||||
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
|
||||
reinterpret_cast<const float*>(scale.data_ptr()));
|
||||
|
||||
void concat_and_cache_mla(
|
||||
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
|
||||
torch::Tensor& k_pe, // [num_tokens, pe_dim]
|
||||
@ -606,50 +547,6 @@ void concat_and_cache_mla(
|
||||
CALL_CONCAT_AND_CACHE_MLA);
|
||||
}
|
||||
|
||||
// Note(hc): cp_fused_concat_and_cache_mla fuses the following three kernel
|
||||
// calls into one:
|
||||
// k_c_normed.index_select(0, cp_local_token_select_indices) + \
|
||||
// k_pe.squeeze(1).index_select(0, cp_local_token_select_indices) + \
|
||||
// concat_and_cache_mla.
|
||||
void cp_fused_concat_and_cache_mla(
|
||||
torch::Tensor& kv_c, // [num_total_tokens, kv_lora_rank]
|
||||
torch::Tensor& k_pe, // [num_total_tokens, pe_dim]
|
||||
torch::Tensor& cp_local_token_select_indices, // [num_tokens]
|
||||
torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank +
|
||||
// pe_dim)]
|
||||
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
|
||||
const std::string& kv_cache_dtype, torch::Tensor& scale) {
|
||||
// NOTE(woosuk): In vLLM V1, key.size(0) can be different from
|
||||
// slot_mapping.size(0) because of padding for CUDA graphs.
|
||||
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
|
||||
// both include padding.
|
||||
// In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0)
|
||||
// since key includes padding for CUDA graphs, while slot_mapping does not.
|
||||
// In this case, slot_mapping.size(0) represents the actual number of tokens
|
||||
// before padding.
|
||||
// For compatibility with both cases, we use slot_mapping.size(0) as the
|
||||
// number of tokens.
|
||||
int num_tokens = slot_mapping.size(0);
|
||||
int kv_lora_rank = kv_c.size(1);
|
||||
int pe_dim = k_pe.size(1);
|
||||
int block_size = kv_cache.size(1);
|
||||
|
||||
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
|
||||
|
||||
int kv_c_stride = kv_c.stride(0);
|
||||
int k_pe_stride = k_pe.stride(0);
|
||||
int block_stride = kv_cache.stride(0);
|
||||
int entry_stride = kv_cache.stride(1);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(kv_lora_rank, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
|
||||
CALL_CP_FUSED_CONCAT_AND_CACHE_MLA);
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||
|
||||
@ -693,16 +693,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
" Tensor scale) -> ()");
|
||||
cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla);
|
||||
|
||||
cache_ops.def(
|
||||
"cp_fused_concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
|
||||
" Tensor cp_local_token_select_indices,"
|
||||
" Tensor! kv_cache,"
|
||||
" Tensor slot_mapping,"
|
||||
" str kv_cache_dtype,"
|
||||
" Tensor scale) -> ()");
|
||||
cache_ops.impl("cp_fused_concat_and_cache_mla", torch::kCUDA,
|
||||
&cp_fused_concat_and_cache_mla);
|
||||
|
||||
// Convert the key and value cache to fp8 data type.
|
||||
cache_ops.def(
|
||||
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
|
||||
|
||||
263
tests/distributed/test_context_parallel.py
Normal file
263
tests/distributed/test_context_parallel.py
Normal file
@ -0,0 +1,263 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
WARNING: This test runs in both single-node (4 GPUs) and multi-node
|
||||
(2 node with 2 GPUs each) modes. If the test only uses 2 GPUs, it is
|
||||
important to set the distributed backend to "mp" to avoid Ray scheduling
|
||||
all workers in a node other than the head node, which can cause the test
|
||||
to fail.
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, NamedTuple, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import RunnerOption
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from ..models.registry import HF_EXAMPLE_MODELS
|
||||
from ..utils import compare_two_settings, create_new_process_for_each_test
|
||||
|
||||
logger = init_logger("test_context_parallel")
|
||||
|
||||
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
|
||||
|
||||
|
||||
class ParallelSetup(NamedTuple):
|
||||
tp_size: int
|
||||
pp_size: int
|
||||
dcp_size: int
|
||||
eager_mode: bool
|
||||
chunked_prefill: bool
|
||||
|
||||
|
||||
class CPTestOptions(NamedTuple):
|
||||
multi_node_only: bool
|
||||
load_format: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CPTestSettings:
|
||||
parallel_setups: list[ParallelSetup]
|
||||
# NOTE: the length of distributed_backends and
|
||||
# vllm_major_versions should be the same, and they
|
||||
# are first zipped together to iterate over all
|
||||
# test settings.
|
||||
distributed_backends: list[str]
|
||||
# vllm major version: "0" for V0, "1" for V1
|
||||
vllm_major_versions: list[str]
|
||||
runner: RunnerOption
|
||||
test_options: CPTestOptions
|
||||
|
||||
def __post_init__(self):
|
||||
if len(self.distributed_backends) != len(self.vllm_major_versions):
|
||||
raise ValueError(
|
||||
f"Length mismatch: distributed_backends "
|
||||
f"({len(self.distributed_backends)}) != "
|
||||
f"vllm_major_versions ({len(self.vllm_major_versions)})")
|
||||
|
||||
@staticmethod
|
||||
def detailed(
|
||||
*,
|
||||
tp_base: int = 4,
|
||||
pp_base: int = 1,
|
||||
dcp_base: int = 1,
|
||||
multi_node_only: bool = False,
|
||||
runner: RunnerOption = "auto",
|
||||
load_format: Optional[str] = None,
|
||||
):
|
||||
parallel_setups = []
|
||||
for eager_mode_val in [False]:
|
||||
for pp_multiplier in [1]:
|
||||
for dcp_multiplier in [2, 4]:
|
||||
for chunked_prefill_val in [True]:
|
||||
parallel_setups.append(
|
||||
ParallelSetup(tp_size=tp_base,
|
||||
pp_size=pp_multiplier * pp_base,
|
||||
dcp_size=dcp_multiplier * dcp_base,
|
||||
eager_mode=eager_mode_val,
|
||||
chunked_prefill=chunked_prefill_val))
|
||||
return CPTestSettings(
|
||||
parallel_setups=parallel_setups,
|
||||
distributed_backends=["mp"],
|
||||
vllm_major_versions=["1"],
|
||||
runner=runner,
|
||||
test_options=CPTestOptions(multi_node_only=multi_node_only,
|
||||
load_format=load_format),
|
||||
)
|
||||
|
||||
def iter_params(self, model_id: str):
|
||||
opts = self.test_options
|
||||
|
||||
for parallel_setup in self.parallel_setups:
|
||||
for backend, vllm_major_version in zip(self.distributed_backends,
|
||||
self.vllm_major_versions):
|
||||
yield (model_id, parallel_setup, backend, vllm_major_version,
|
||||
self.runner, opts)
|
||||
|
||||
|
||||
def _compare_cp_with_tp(
|
||||
model_id: str,
|
||||
parallel_setup: ParallelSetup,
|
||||
distributed_backend: str,
|
||||
vllm_major_version: str,
|
||||
runner: RunnerOption,
|
||||
test_options: CPTestOptions,
|
||||
num_gpus_available: int,
|
||||
*,
|
||||
method: Literal["generate"],
|
||||
is_multimodal: bool,
|
||||
):
|
||||
(
|
||||
tp_size,
|
||||
pp_size,
|
||||
dcp_size,
|
||||
eager_mode,
|
||||
chunked_prefill,
|
||||
) = parallel_setup
|
||||
|
||||
multi_node_only, load_format = test_options
|
||||
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
|
||||
trust_remote_code = model_info.trust_remote_code
|
||||
tokenizer_mode = model_info.tokenizer_mode
|
||||
hf_overrides = model_info.hf_overrides
|
||||
|
||||
if load_format == "dummy":
|
||||
# Avoid OOM
|
||||
text_overrides = {
|
||||
"num_hidden_layers": 4,
|
||||
"hidden_size": 512,
|
||||
"intermediate_size": 800,
|
||||
"num_attention_heads": 4,
|
||||
"num_key_value_heads": 1,
|
||||
}
|
||||
|
||||
if is_multimodal:
|
||||
hf_overrides.update({"text_config": text_overrides})
|
||||
else:
|
||||
hf_overrides.update(text_overrides)
|
||||
else:
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
|
||||
if num_gpus_available < tp_size * pp_size:
|
||||
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
|
||||
if VLLM_MULTI_NODE and distributed_backend == "mp":
|
||||
pytest.skip("Skipping multi-node pipeline parallel test for "
|
||||
"multiprocessing distributed backend")
|
||||
if multi_node_only and not VLLM_MULTI_NODE:
|
||||
pytest.skip("Not in multi-node setting")
|
||||
|
||||
common_args = [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--max-model-len",
|
||||
"2048",
|
||||
"--max-num-seqs",
|
||||
"8",
|
||||
]
|
||||
if chunked_prefill:
|
||||
common_args.append("--enable-chunked-prefill")
|
||||
if eager_mode:
|
||||
common_args.append("--enforce-eager")
|
||||
if runner != "auto":
|
||||
common_args.extend(["--runner", runner])
|
||||
if trust_remote_code:
|
||||
common_args.append("--trust-remote-code")
|
||||
if tokenizer_mode:
|
||||
common_args.extend(["--tokenizer-mode", tokenizer_mode])
|
||||
if load_format:
|
||||
common_args.extend(["--load-format", load_format])
|
||||
if hf_overrides:
|
||||
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
|
||||
|
||||
cp_env = tp_env = {
|
||||
"VLLM_USE_V1":
|
||||
vllm_major_version, # Note(hc): DCP only support V1 engine only
|
||||
}
|
||||
|
||||
cp_args = [
|
||||
*common_args,
|
||||
"--tensor-parallel-size",
|
||||
str(tp_size),
|
||||
"--pipeline-parallel-size",
|
||||
str(pp_size),
|
||||
"--decode-context-parallel-size",
|
||||
str(dcp_size),
|
||||
"--distributed-executor-backend",
|
||||
distributed_backend,
|
||||
]
|
||||
|
||||
tp_args = [
|
||||
*common_args,
|
||||
"--tensor-parallel-size",
|
||||
str(tp_size),
|
||||
"--pipeline-parallel-size",
|
||||
str(pp_size),
|
||||
"--distributed-executor-backend",
|
||||
distributed_backend,
|
||||
]
|
||||
|
||||
try:
|
||||
compare_two_settings(model_id,
|
||||
cp_args,
|
||||
tp_args,
|
||||
cp_env,
|
||||
tp_env,
|
||||
method=method,
|
||||
max_wait_seconds=720)
|
||||
except Exception:
|
||||
testing_ray_compiled_graph = cp_env is not None
|
||||
if testing_ray_compiled_graph and vllm_major_version == "0":
|
||||
# Ray Compiled Graph tests are flaky for V0,
|
||||
# so we don't want to fail the test
|
||||
logger.exception("Ray Compiled Graph tests failed")
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
CP_TEXT_GENERATION_MODELS = {
|
||||
# [MLA attention only]
|
||||
"deepseek-ai/DeepSeek-V2-Lite-Chat": CPTestSettings.detailed(),
|
||||
}
|
||||
|
||||
CP_TEST_MODELS = [
|
||||
# TODO support other models
|
||||
# [LANGUAGE GENERATION]
|
||||
"deepseek-ai/DeepSeek-V2-Lite-Chat",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
|
||||
"runner", "test_options"),
|
||||
[
|
||||
params for model_id, settings in CP_TEXT_GENERATION_MODELS.items()
|
||||
for params in settings.iter_params(model_id)
|
||||
if model_id in CP_TEST_MODELS
|
||||
],
|
||||
)
|
||||
@create_new_process_for_each_test()
|
||||
def test_cp_generation(
|
||||
model_id: str,
|
||||
parallel_setup: ParallelSetup,
|
||||
distributed_backend: str,
|
||||
vllm_major_version: str,
|
||||
runner: RunnerOption,
|
||||
test_options: CPTestOptions,
|
||||
num_gpus_available,
|
||||
):
|
||||
_compare_cp_with_tp(model_id,
|
||||
parallel_setup,
|
||||
distributed_backend,
|
||||
vllm_major_version,
|
||||
runner,
|
||||
test_options,
|
||||
num_gpus_available,
|
||||
method="generate",
|
||||
is_multimodal=False)
|
||||
@ -1625,20 +1625,6 @@ def concat_and_cache_mla(
|
||||
scale)
|
||||
|
||||
|
||||
def cp_fused_concat_and_cache_mla(
|
||||
kv_c: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
cp_local_token_select_indices: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
scale: torch.Tensor,
|
||||
) -> None:
|
||||
torch.ops._C_cache_ops.cp_fused_concat_and_cache_mla(
|
||||
kv_c, k_pe, cp_local_token_select_indices, kv_cache, slot_mapping,
|
||||
kv_cache_dtype, scale)
|
||||
|
||||
|
||||
def copy_blocks(key_caches: list[torch.Tensor],
|
||||
value_caches: list[torch.Tensor],
|
||||
block_mapping: torch.Tensor) -> None:
|
||||
|
||||
139
vllm/attention/ops/common.py
Normal file
139
vllm/attention/ops/common.py
Normal file
@ -0,0 +1,139 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.distributed.parallel_state import GroupCoordinator
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _correct_attn_cp_out_kernel(outputs_ptr, new_output_ptr, lses_ptr,
|
||||
vlse_ptr, outputs_stride_B, outputs_stride_H,
|
||||
outputs_stride_D, lses_stride_N, lses_stride_B,
|
||||
lses_stride_H, lse_idx, HEAD_DIM: tl.constexpr,
|
||||
N_ROUNDED: tl.constexpr):
|
||||
"""
|
||||
Apply the all-gathered lses to correct each local rank's attention
|
||||
output. we still need perform a cross-rank reduction to obtain the
|
||||
final attention output.
|
||||
|
||||
Args:
|
||||
output: [ B, H, D ]
|
||||
lses : [ N, B, H ]
|
||||
cp, batch, q_heads, v_head_dim
|
||||
Return:
|
||||
output: [ B, H, D ]
|
||||
lse : [ B, H ]
|
||||
"""
|
||||
batch_idx = tl.program_id(axis=0).to(tl.int64)
|
||||
head_idx = tl.program_id(axis=1).to(tl.int64)
|
||||
d_offsets = tl.arange(0, HEAD_DIM)
|
||||
num_n_offsets = tl.arange(0, N_ROUNDED)
|
||||
|
||||
# shape = [N]
|
||||
lse_offsets = num_n_offsets * lses_stride_N + batch_idx * \
|
||||
lses_stride_B + head_idx * lses_stride_H
|
||||
|
||||
# calc final lse
|
||||
lse = tl.load(lses_ptr + lse_offsets)
|
||||
lse = tl.where((lse != lse) | (lse == float('inf')), -float('inf'), lse)
|
||||
lse_max = tl.max(lse, axis=0)
|
||||
lse -= lse_max
|
||||
lse_exp = tl.exp(lse)
|
||||
lse_acc = tl.sum(lse_exp, axis=0)
|
||||
lse = tl.log(lse_acc)
|
||||
lse += lse_max
|
||||
|
||||
lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H
|
||||
tl.store(vlse_ptr + lse_offsets, lse)
|
||||
|
||||
# shape = [D]
|
||||
output_offsets = batch_idx * outputs_stride_B + \
|
||||
head_idx * outputs_stride_H + \
|
||||
d_offsets * outputs_stride_D
|
||||
|
||||
# correct output
|
||||
lse_offset = lse_idx * lses_stride_N + batch_idx * \
|
||||
lses_stride_B + head_idx * lses_stride_H
|
||||
lse_tmp = tl.load(lses_ptr + lse_offset)
|
||||
lse_finally = lse_tmp - lse
|
||||
lse_finally = tl.where(
|
||||
(lse_finally != lse_finally) | (lse_finally == float('inf')),
|
||||
-float('inf'), lse_finally)
|
||||
factor = tl.exp(lse_finally)
|
||||
output = tl.load(outputs_ptr + output_offsets)
|
||||
output = output * factor
|
||||
|
||||
tl.store(new_output_ptr + output_offsets, output)
|
||||
|
||||
|
||||
class CPTritonContext:
|
||||
""" The CPTritonContext is used to avoid recompilation of the Triton JIT.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.inner_kernel = None
|
||||
|
||||
def call_kernel(self, kernel, grid, *regular_args, **const_args):
|
||||
if self.inner_kernel is None:
|
||||
self.inner_kernel = kernel[grid](*regular_args, **const_args)
|
||||
else:
|
||||
self.inner_kernel[grid](*regular_args)
|
||||
|
||||
|
||||
def correct_attn_out(out: torch.Tensor, lses: torch.Tensor, cp_rank: int,
|
||||
ctx: CPTritonContext):
|
||||
"""
|
||||
Apply the all-gathered lses to correct each local rank's attention
|
||||
output. we still need perform a cross-rank reduction to obtain the
|
||||
final attention output.
|
||||
|
||||
Args:
|
||||
output: [ B, H, D ]
|
||||
lses : [ N, B, H ]
|
||||
Return:
|
||||
output: [ B, H, D ]
|
||||
lse : [ B, H ]
|
||||
"""
|
||||
if ctx is None:
|
||||
ctx = CPTritonContext()
|
||||
|
||||
lse = torch.empty_like(lses[0])
|
||||
|
||||
grid = (out.shape[0], out.shape[1], 1)
|
||||
regular_args = (out, out, lses, lse, *out.stride(), *lses.stride(),
|
||||
cp_rank)
|
||||
const_args = {
|
||||
"HEAD_DIM": out.shape[-1],
|
||||
"N_ROUNDED": lses.shape[0],
|
||||
}
|
||||
|
||||
ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args,
|
||||
**const_args)
|
||||
return out, lse
|
||||
|
||||
|
||||
def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor,
|
||||
cp_attn_lse: torch.Tensor,
|
||||
cp_group: GroupCoordinator,
|
||||
ctx: CPTritonContext = None):
|
||||
"""
|
||||
cp_attn_out: [ B, H, D ]
|
||||
cp_attn_lse: [ B, H ]
|
||||
"""
|
||||
if cp_group.world_size == 1:
|
||||
return cp_attn_out
|
||||
|
||||
if ctx is None:
|
||||
ctx = CPTritonContext()
|
||||
|
||||
lses = torch.empty((cp_group.world_size, ) + cp_attn_lse.shape,
|
||||
dtype=cp_attn_lse.dtype,
|
||||
device=cp_attn_lse.device)
|
||||
|
||||
cp_attn_lse = cp_attn_lse.contiguous()
|
||||
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
|
||||
out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
|
||||
assert out.is_contiguous()
|
||||
out = cp_group.reduce_scatter(out, dim=1)
|
||||
return out
|
||||
@ -105,7 +105,9 @@ def flash_mla_with_kvcache(
|
||||
descale_q,
|
||||
descale_k,
|
||||
)
|
||||
return out, softmax_lse
|
||||
|
||||
# Note(hc): need revisit when we support DCP with decode query_len > 1.
|
||||
return out.squeeze(1), softmax_lse.squeeze(-1)
|
||||
|
||||
|
||||
#
|
||||
|
||||
@ -170,6 +170,11 @@ class ParallelConfig:
|
||||
Set to be private as it's not intended to be configured by users.
|
||||
"""
|
||||
|
||||
decode_context_parallel_size: int = 1
|
||||
"""Number of decode context parallel groups, because the world size does
|
||||
not change by dcp, it simply reuse the GPUs of TP group, and tp_size
|
||||
needs to be divisible by dcp_size."""
|
||||
|
||||
@property
|
||||
def world_size_across_dp(self) -> int:
|
||||
"""world_size_across_dp is TPxPPxDP, it is the size of the world
|
||||
|
||||
@ -904,6 +904,18 @@ def get_tensor_model_parallel_group():
|
||||
return get_tp_group()
|
||||
|
||||
|
||||
_DCP: Optional[GroupCoordinator] = None
|
||||
|
||||
|
||||
def get_dcp_group() -> GroupCoordinator:
|
||||
assert _DCP is not None, (
|
||||
"decode context model parallel group is not initialized")
|
||||
return _DCP
|
||||
|
||||
|
||||
# kept for backward compatibility
|
||||
get_context_model_parallel_group = get_dcp_group
|
||||
|
||||
_PP: Optional[GroupCoordinator] = None
|
||||
|
||||
_DP: Optional[GroupCoordinator] = None
|
||||
@ -1034,6 +1046,7 @@ def init_distributed_environment(
|
||||
def initialize_model_parallel(
|
||||
tensor_model_parallel_size: int = 1,
|
||||
pipeline_model_parallel_size: int = 1,
|
||||
decode_context_model_parallel_size: Optional[int] = 1,
|
||||
backend: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
@ -1098,6 +1111,23 @@ def initialize_model_parallel(
|
||||
use_message_queue_broadcaster=True,
|
||||
group_name="tp")
|
||||
|
||||
# Build the DCP model-parallel groups.
|
||||
global _DCP
|
||||
assert _DCP is None, (
|
||||
"decode context model parallel group is already initialized")
|
||||
# Note(hc): In the current implementation of decode context parallel,
|
||||
# dcp_size must not exceed tp_size, because the world size does not
|
||||
# change by DCP, it simply reuse the GPUs of TP group, and split one
|
||||
# TP group into tp_size//dcp_size DCP groups.
|
||||
group_ranks = all_ranks.reshape(
|
||||
-1, decode_context_model_parallel_size).unbind(0)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
_DCP = init_model_parallel_group(group_ranks,
|
||||
get_world_group().local_rank,
|
||||
backend,
|
||||
use_message_queue_broadcaster=True,
|
||||
group_name="dcp")
|
||||
|
||||
# Build the pipeline model-parallel groups.
|
||||
global _PP
|
||||
assert _PP is None, (
|
||||
@ -1141,6 +1171,7 @@ def initialize_model_parallel(
|
||||
def ensure_model_parallel_initialized(
|
||||
tensor_model_parallel_size: int,
|
||||
pipeline_model_parallel_size: int,
|
||||
decode_context_model_parallel_size: Optional[int] = 1,
|
||||
backend: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Helper to initialize model parallel groups if they are not initialized,
|
||||
@ -1151,7 +1182,8 @@ def ensure_model_parallel_initialized(
|
||||
get_world_group().device_group)
|
||||
if not model_parallel_is_initialized():
|
||||
initialize_model_parallel(tensor_model_parallel_size,
|
||||
pipeline_model_parallel_size, backend)
|
||||
pipeline_model_parallel_size,
|
||||
decode_context_model_parallel_size, backend)
|
||||
return
|
||||
|
||||
assert (
|
||||
@ -1226,6 +1258,16 @@ def get_tensor_model_parallel_rank():
|
||||
return get_tp_group().rank_in_group
|
||||
|
||||
|
||||
def get_decode_context_model_parallel_world_size():
|
||||
"""Return world size for the decode context model parallel group."""
|
||||
return get_dcp_group().world_size
|
||||
|
||||
|
||||
def get_decode_context_model_parallel_rank():
|
||||
"""Return my rank for the decode context model parallel group."""
|
||||
return get_dcp_group().rank_in_group
|
||||
|
||||
|
||||
def get_node_count() -> int:
|
||||
"""Return the total number of nodes in the distributed environment. """
|
||||
assert _NODE_COUNT is not None, (
|
||||
@ -1246,6 +1288,11 @@ def destroy_model_parallel():
|
||||
_PP.destroy()
|
||||
_PP = None
|
||||
|
||||
global _DCP
|
||||
if _DCP:
|
||||
_DCP.destroy()
|
||||
_DCP = None
|
||||
|
||||
global _DP
|
||||
if _DP:
|
||||
_DP.destroy()
|
||||
|
||||
@ -306,6 +306,8 @@ class EngineArgs:
|
||||
# number of P/D disaggregation (or other disaggregation) workers
|
||||
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
|
||||
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
|
||||
decode_context_parallel_size: int = \
|
||||
ParallelConfig.decode_context_parallel_size
|
||||
data_parallel_size: int = ParallelConfig.data_parallel_size
|
||||
data_parallel_rank: Optional[int] = None
|
||||
data_parallel_start_rank: Optional[int] = None
|
||||
@ -636,6 +638,9 @@ class EngineArgs:
|
||||
**parallel_kwargs["pipeline_parallel_size"])
|
||||
parallel_group.add_argument("--tensor-parallel-size", "-tp",
|
||||
**parallel_kwargs["tensor_parallel_size"])
|
||||
parallel_group.add_argument(
|
||||
"--decode-context-parallel-size", "-dcp",
|
||||
**parallel_kwargs["decode_context_parallel_size"])
|
||||
parallel_group.add_argument("--data-parallel-size", "-dp",
|
||||
**parallel_kwargs["data_parallel_size"])
|
||||
parallel_group.add_argument(
|
||||
@ -1156,6 +1161,17 @@ class EngineArgs:
|
||||
# global layers in interleaved sliding window models.
|
||||
sliding_window = model_config.get_sliding_window()
|
||||
|
||||
# Note(hc): In the current implementation of decode context
|
||||
# parallel(DCP), tp_size needs to be divisible by dcp_size,
|
||||
# because the world size does not change by dcp, it simply
|
||||
# reuse the GPUs of TP group, and split one TP group into
|
||||
# tp_size//dcp_size DCP groups.
|
||||
assert self.tensor_parallel_size % self.decode_context_parallel_size \
|
||||
== 0, (
|
||||
f"tp_size={self.tensor_parallel_size} must be divisible by"
|
||||
f"dcp_size={self.decode_context_parallel_size}."
|
||||
)
|
||||
|
||||
cache_config = CacheConfig(
|
||||
block_size=self.block_size,
|
||||
gpu_memory_utilization=self.gpu_memory_utilization,
|
||||
@ -1306,6 +1322,7 @@ class EngineArgs:
|
||||
distributed_executor_backend=self.distributed_executor_backend,
|
||||
worker_cls=self.worker_cls,
|
||||
worker_extension_cls=self.worker_extension_cls,
|
||||
decode_context_parallel_size=self.decode_context_parallel_size,
|
||||
)
|
||||
|
||||
speculative_config = self.create_speculative_config(
|
||||
|
||||
@ -201,10 +201,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
||||
AttentionMetadata,
|
||||
MLAAttentionImpl)
|
||||
from vllm.attention.backends.utils import get_mla_dims
|
||||
from vllm.attention.ops.common import cp_lse_ag_out_rs
|
||||
from vllm.attention.ops.merge_attn_states import merge_attn_states
|
||||
from vllm.attention.utils.fa_utils import get_flash_attn_version
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.parallel_state import is_global_first_rank
|
||||
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearBase,
|
||||
@ -323,6 +324,13 @@ class MLACommonPrefillMetadata:
|
||||
seq_lens: torch.Tensor
|
||||
workspace: torch.Tensor
|
||||
|
||||
# for mla DCP
|
||||
cp_chunk_seq_lens: Optional[list[list[int]]] = None
|
||||
origin_context_lens: Optional[list[int]] = None
|
||||
cp_cu_seq_lens: Optional[torch.Tensor] = None
|
||||
chunk_size: Optional[int] = None
|
||||
cu_seq_lens_lst: Optional[list[list[int]]] = None
|
||||
|
||||
block_table: torch.Tensor
|
||||
query_start_loc: torch.Tensor
|
||||
max_query_len: int
|
||||
@ -444,6 +452,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
parallel_config)
|
||||
self.mla_dims = get_mla_dims(self.model_config)
|
||||
self.aot_schedule = current_platform.is_cuda()
|
||||
try:
|
||||
self.dcp_world_size = get_dcp_group().world_size
|
||||
self.dcp_rank = get_dcp_group().rank_in_group
|
||||
except AssertionError:
|
||||
# DCP might not be initialized in testing
|
||||
self.dcp_world_size = 1
|
||||
self.dcp_rank = 0
|
||||
|
||||
# Dont try to access the runner on AMD
|
||||
if self.aot_schedule:
|
||||
@ -465,12 +480,27 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
128 * 1024)
|
||||
assert self.chunked_prefill_workspace_size >= \
|
||||
scheduler_config.max_num_seqs * cache_config.block_size
|
||||
self.chunked_prefill_workspace = torch.empty(
|
||||
(self.chunked_prefill_workspace_size,
|
||||
self.model_config.get_head_size()),
|
||||
dtype=self.model_config.dtype,
|
||||
device=device,
|
||||
)
|
||||
if self.dcp_world_size > 1:
|
||||
# Note(hc): The local kvcache is incomplete when DCP is triggered,
|
||||
# an additional kvcache allgather across the DCP group is therefore
|
||||
# required, so the workspace has to be enlarged by 1/DCP relative
|
||||
# to the original TP allocation.
|
||||
assert self.chunked_prefill_workspace_size % \
|
||||
self.dcp_world_size == 0
|
||||
self.chunked_prefill_workspace = torch.empty(
|
||||
(self.chunked_prefill_workspace_size +
|
||||
self.chunked_prefill_workspace_size // self.dcp_world_size,
|
||||
self.model_config.get_head_size()),
|
||||
dtype=self.model_config.dtype,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.chunked_prefill_workspace = torch.empty(
|
||||
(self.chunked_prefill_workspace_size,
|
||||
self.model_config.get_head_size()),
|
||||
dtype=self.model_config.dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self._use_cudnn_prefill = use_cudnn_prefill()
|
||||
self._use_fi_prefill = use_flashinfer_prefill()
|
||||
@ -631,6 +661,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
split_decodes_and_prefills(common_attn_metadata,
|
||||
decode_threshold=self.reorder_batch_threshold)
|
||||
|
||||
# Note(hc): update seq_lens of decode reqs under DCP.
|
||||
if self.dcp_world_size > 1:
|
||||
seq_lens[:num_decodes] = seq_lens[:num_decodes] \
|
||||
// self.dcp_world_size + (self.dcp_rank <= \
|
||||
(seq_lens[:num_decodes] - 1) % self.dcp_world_size)
|
||||
|
||||
assert num_decodes + num_prefills == num_reqs
|
||||
assert num_decode_tokens + num_prefill_tokens == num_tokens
|
||||
|
||||
@ -639,6 +675,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
reqs_start = num_decodes # prefill_start
|
||||
|
||||
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
|
||||
# Note(hc): The context lengths in the perspective of dcp rank0.
|
||||
cp_context_lens_cpu = torch.ceil(context_lens_cpu.float() /
|
||||
self.dcp_world_size).int()
|
||||
origin_context_lens = context_lens_cpu.tolist()
|
||||
max_context_len_cpu = context_lens_cpu.max().item()
|
||||
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
|
||||
prefill_query_start_loc = query_start_loc[
|
||||
@ -691,20 +731,66 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
out=cu_seq_lens_cpu[:, 1:],
|
||||
dtype=torch.int32)
|
||||
|
||||
if self.dcp_world_size > 1:
|
||||
# Note(hc): The above max_context_chunk already enforces
|
||||
# block_size alignment, DCP just need the block_size can
|
||||
# be divisible by dcp_world_size, because DCP use
|
||||
# cp_gather_cache which not require `cp_chunk_starts`
|
||||
# aligned to page_size.
|
||||
assert max_context_chunk % self.dcp_world_size == 0
|
||||
cp_max_context_chunk = max_context_chunk // \
|
||||
self.dcp_world_size
|
||||
cp_chunk_starts = \
|
||||
torch.arange(num_chunks, dtype=torch.int32) \
|
||||
.unsqueeze(1).expand(-1, num_prefills) \
|
||||
* cp_max_context_chunk
|
||||
cp_chunk_ends = torch.min(
|
||||
cp_context_lens_cpu.unsqueeze(0),
|
||||
cp_chunk_starts + cp_max_context_chunk)
|
||||
cp_chunk_seq_lens = (cp_chunk_ends -
|
||||
cp_chunk_starts).clamp(min=0)
|
||||
|
||||
cp_cu_seq_lens_cpu = torch.zeros(num_chunks,
|
||||
num_prefills + 1,
|
||||
dtype=torch.int32,
|
||||
pin_memory=True)
|
||||
torch.cumsum(cp_chunk_seq_lens,
|
||||
dim=1,
|
||||
out=cp_cu_seq_lens_cpu[:, 1:],
|
||||
dtype=torch.int32)
|
||||
|
||||
chunked_context_metadata_cls = \
|
||||
CudnnPrefillMetadata.ChunkedContextMetadata \
|
||||
if self._use_cudnn_prefill else \
|
||||
MLACommonPrefillMetadata.ChunkedContextMetadata
|
||||
|
||||
chunked_context_metadata = \
|
||||
chunked_context_metadata_cls(
|
||||
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
|
||||
starts=chunk_starts.to(device, non_blocking=True),
|
||||
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||
seq_lens=chunk_seq_lens,
|
||||
workspace=self.chunked_prefill_workspace,
|
||||
)
|
||||
if self.dcp_world_size > 1:
|
||||
chunked_context_metadata = \
|
||||
chunked_context_metadata_cls(
|
||||
cu_seq_lens=cu_seq_lens_cpu \
|
||||
.to(device, non_blocking=True),
|
||||
starts=cp_chunk_starts.to(device, non_blocking=True),
|
||||
seq_tot=cp_chunk_seq_lens.sum(dim=1).tolist(),
|
||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||
seq_lens=chunk_seq_lens,
|
||||
workspace=self.chunked_prefill_workspace,
|
||||
cp_chunk_seq_lens=cp_chunk_seq_lens.tolist(),
|
||||
origin_context_lens=origin_context_lens,
|
||||
cp_cu_seq_lens=cp_cu_seq_lens_cpu \
|
||||
.to(device, non_blocking=True),
|
||||
chunk_size=max_context_chunk,
|
||||
cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
|
||||
)
|
||||
else:
|
||||
chunked_context_metadata = \
|
||||
chunked_context_metadata_cls(
|
||||
cu_seq_lens=cu_seq_lens_cpu \
|
||||
.to(device, non_blocking=True),
|
||||
starts=chunk_starts.to(device, non_blocking=True),
|
||||
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||
seq_lens=chunk_seq_lens,
|
||||
workspace=self.chunked_prefill_workspace,
|
||||
)
|
||||
|
||||
if self._use_cudnn_prefill:
|
||||
chunked_context_metadata.seq_lens = chunk_seq_lens
|
||||
@ -757,6 +843,71 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
return attn_metadata
|
||||
|
||||
|
||||
def reorg_kvcache(
|
||||
allgatered_kv_c_normed: torch.Tensor,
|
||||
allgatered_k_pe: torch.Tensor,
|
||||
cp_chunk_seq_lens_lst: list[int],
|
||||
origin_context_lens: list[int],
|
||||
cp_world_size: int,
|
||||
sum_seq_len: int,
|
||||
max_seq_len: int,
|
||||
chunk_size: int,
|
||||
chunk_idx: int,
|
||||
toks: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
reorg kvcache after cp local gather to tp layout for attn kernel.
|
||||
|
||||
Args:
|
||||
cp_chunk_seq_lens_lst: chunk context lengths under CP.
|
||||
origin_context_lens: origin full context lengths under CP.
|
||||
cp_world_size: CP size.
|
||||
sum_seq_len: the sum of cp_chunk_seq_lens_lst.
|
||||
max_seq_len: the max value of cp_chunk_seq_lens_lst.
|
||||
chunk_size: equals to max_context_chunk from
|
||||
chunked_context_metadata building.
|
||||
chunk_idx: chunk idx of chunked_prefill.
|
||||
toks: the number of tokens for local gather cache.
|
||||
"""
|
||||
kv_c_segments = []
|
||||
k_pe_segments = []
|
||||
src_token_idx = 0
|
||||
max_seq_len_check = 0
|
||||
for cp_chunk_seq_len, origin_context_len in zip(cp_chunk_seq_lens_lst,
|
||||
origin_context_lens):
|
||||
chunk_context_len = chunk_size
|
||||
if cp_chunk_seq_len != 0:
|
||||
chunk_context_len = min(
|
||||
chunk_context_len, origin_context_len - chunk_size * chunk_idx)
|
||||
cp_target_rank = (chunk_context_len - 1) % cp_world_size
|
||||
cur_seq_len = 0
|
||||
for rank in range(cp_world_size):
|
||||
if rank > cp_target_rank and cp_chunk_seq_len:
|
||||
real_cp_chunk_seq_len = cp_chunk_seq_len - 1
|
||||
else:
|
||||
real_cp_chunk_seq_len = cp_chunk_seq_len
|
||||
if real_cp_chunk_seq_len:
|
||||
kv_c_segment = allgatered_kv_c_normed[rank * toks +
|
||||
src_token_idx:rank *
|
||||
toks + src_token_idx +
|
||||
real_cp_chunk_seq_len]
|
||||
k_pe_segment = allgatered_k_pe[rank * toks +
|
||||
src_token_idx:rank * toks +
|
||||
src_token_idx +
|
||||
real_cp_chunk_seq_len]
|
||||
kv_c_segments.append(kv_c_segment)
|
||||
k_pe_segments.append(k_pe_segment)
|
||||
cur_seq_len += real_cp_chunk_seq_len
|
||||
max_seq_len_check = max(max_seq_len_check, cur_seq_len)
|
||||
src_token_idx += cp_chunk_seq_len
|
||||
reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
|
||||
reorganized_k_pe = torch.cat(k_pe_segments, dim=0)
|
||||
assert reorganized_kv_c_normed.shape[0] == sum_seq_len
|
||||
assert reorganized_k_pe.shape[0] == sum_seq_len
|
||||
assert max_seq_len_check == max_seq_len
|
||||
return reorganized_kv_c_normed, reorganized_k_pe
|
||||
|
||||
|
||||
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
"""
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
@ -836,6 +987,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
self.vllm_flash_attn_version == 3
|
||||
and current_platform.get_device_capability()[0] == 9)
|
||||
|
||||
self.dcp_world_size: Optional[int] = None
|
||||
|
||||
def _flash_attn_varlen_diff_headdims(self,
|
||||
q,
|
||||
k,
|
||||
@ -1152,6 +1305,108 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
|
||||
return output, output_lse
|
||||
|
||||
def _context_parallel_compute_prefill_context(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
k_scale: torch.Tensor,
|
||||
dcp_world_size: int,
|
||||
):
|
||||
assert k_scale is None, "DCP not support sacled kvcache now."
|
||||
assert attn_metadata.prefill is not None
|
||||
prefill_metadata = attn_metadata.prefill
|
||||
assert prefill_metadata.chunked_context is not None
|
||||
assert prefill_metadata.chunked_context.cp_chunk_seq_lens is not None
|
||||
assert prefill_metadata.chunked_context.origin_context_lens is not None
|
||||
assert prefill_metadata.chunked_context.cp_cu_seq_lens is not None
|
||||
assert prefill_metadata.chunked_context.chunk_size is not None
|
||||
assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None
|
||||
|
||||
output = None
|
||||
iters = len(prefill_metadata.chunked_context.seq_tot)
|
||||
workspace = prefill_metadata.chunked_context.workspace
|
||||
|
||||
for i in range(iters):
|
||||
toks = prefill_metadata.chunked_context.seq_tot[i]
|
||||
ops.cp_gather_cache(
|
||||
src_cache=kv_c_and_k_pe_cache,
|
||||
dst=workspace,
|
||||
block_table=prefill_metadata.block_table,
|
||||
cu_seq_lens=prefill_metadata.chunked_context.cp_cu_seq_lens[i],
|
||||
batch_size=attn_metadata.num_prefills,
|
||||
seq_starts=prefill_metadata.chunked_context.starts[i],
|
||||
)
|
||||
# workspace
|
||||
# |------- N tokens --------|--------- N*dcp_size tokens ----------|
|
||||
# |<- use for loca_gather ->|<--------- use for allgather -------->|
|
||||
allgather_offset = workspace.shape[0] // (dcp_world_size + 1)
|
||||
assert allgather_offset * (dcp_world_size +
|
||||
1) == workspace.shape[0]
|
||||
assert toks <= allgather_offset
|
||||
local_gathered_kvcache = workspace[:toks]
|
||||
cur_allgather_workspace = workspace[
|
||||
allgather_offset:allgather_offset * (1 + dcp_world_size)]
|
||||
assert toks * dcp_world_size <= cur_allgather_workspace.shape[0]
|
||||
cur_allgather_kvcache = cur_allgather_workspace[:toks *
|
||||
dcp_world_size]
|
||||
cur_allgather_kvcache.copy_(get_dcp_group().all_gather(
|
||||
local_gathered_kvcache, dim=0))
|
||||
assert cur_allgather_kvcache.shape[
|
||||
-1] == self.kv_lora_rank + self.qk_rope_head_dim
|
||||
allgatered_kv_c_normed, allgatered_k_pe = \
|
||||
cur_allgather_kvcache.unsqueeze(
|
||||
1).split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
kv_c_normed, k_pe = reorg_kvcache(
|
||||
allgatered_kv_c_normed,
|
||||
allgatered_k_pe,
|
||||
cp_chunk_seq_lens_lst=prefill_metadata.chunked_context.
|
||||
cp_chunk_seq_lens[i],
|
||||
origin_context_lens=prefill_metadata.chunked_context.
|
||||
origin_context_lens,
|
||||
cp_world_size=dcp_world_size,
|
||||
sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i]
|
||||
[-1],
|
||||
max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i],
|
||||
chunk_size=prefill_metadata.chunked_context.chunk_size,
|
||||
chunk_idx=i,
|
||||
toks=toks)
|
||||
|
||||
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
|
||||
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
k_nope, v = kv_nope\
|
||||
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
|
||||
dim=-1)
|
||||
|
||||
attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
|
||||
prefill=prefill_metadata,
|
||||
chunk_idx=i,
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
)
|
||||
|
||||
if output is None:
|
||||
output = attn_output
|
||||
output_lse = attn_softmax_lse
|
||||
else:
|
||||
output_tmp = torch.empty_like(output)
|
||||
output_lse_tmp = torch.empty_like(output_lse)
|
||||
merge_attn_states(
|
||||
output=output_tmp,
|
||||
output_lse=output_lse_tmp,
|
||||
prefix_output=output,
|
||||
prefix_lse=output_lse,
|
||||
suffix_output=attn_output,
|
||||
suffix_lse=attn_softmax_lse,
|
||||
)
|
||||
output = output_tmp
|
||||
output_lse = output_lse_tmp
|
||||
|
||||
return output, output_lse
|
||||
|
||||
def _forward_prefill(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
@ -1162,6 +1417,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
k_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
assert attn_metadata.prefill is not None
|
||||
assert self.dcp_world_size is not None
|
||||
|
||||
has_context = attn_metadata.prefill.chunked_context is not None
|
||||
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
|
||||
@ -1181,8 +1437,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
|
||||
if has_context:
|
||||
suffix_output, suffix_lse = output
|
||||
context_output, context_lse = self._compute_prefill_context( \
|
||||
q, kv_c_and_k_pe_cache, attn_metadata, k_scale)
|
||||
if self.dcp_world_size > 1:
|
||||
context_output, context_lse = \
|
||||
self._context_parallel_compute_prefill_context(
|
||||
q, kv_c_and_k_pe_cache, attn_metadata,
|
||||
k_scale=None, dcp_world_size=self.dcp_world_size)
|
||||
else:
|
||||
context_output, context_lse = \
|
||||
self._compute_prefill_context(
|
||||
q, kv_c_and_k_pe_cache, attn_metadata, k_scale)
|
||||
|
||||
output = torch.empty_like(suffix_output)
|
||||
merge_attn_states(
|
||||
@ -1202,12 +1465,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
@abstractmethod
|
||||
def _forward_decode(
|
||||
self,
|
||||
ql_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: M,
|
||||
layer: AttentionLayer,
|
||||
) -> torch.Tensor:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(
|
||||
@ -1235,6 +1497,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
# same expert outputs.
|
||||
return output.fill_(0)
|
||||
|
||||
if self.dcp_world_size is None:
|
||||
self.dcp_world_size = get_dcp_group().world_size
|
||||
|
||||
fp8_attention = self.kv_cache_dtype.startswith("fp8")
|
||||
|
||||
num_actual_toks = attn_metadata.num_actual_tokens
|
||||
@ -1313,7 +1578,26 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
layer._q_scale)
|
||||
decode_q_pe = decode_q_pe.reshape(q_pe_shape)
|
||||
|
||||
output[:num_decode_tokens] = self._forward_decode(
|
||||
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer)
|
||||
decode_q = (decode_ql_nope, decode_q_pe)
|
||||
if self.dcp_world_size > 1:
|
||||
assert not fp8_attention, "DCP not support fp8 kvcache now."
|
||||
# concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P)
|
||||
decode_q = torch.cat(decode_q, dim=-1)
|
||||
# decode_q do allgather in head dim.
|
||||
decode_q = get_dcp_group().all_gather(decode_q, dim=1)
|
||||
|
||||
# call decode attn
|
||||
attn_out, lse = self._forward_decode(decode_q, kv_cache,
|
||||
attn_metadata, layer)
|
||||
|
||||
# recorect dcp attn_out with lse.
|
||||
if self.dcp_world_size > 1:
|
||||
assert lse is not None, (
|
||||
"For a mla backend want to enable"
|
||||
"DCP, it is mandatory that the corresponding decode attn"
|
||||
"kernel return the softmax lse.")
|
||||
attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
|
||||
|
||||
# v_up projection
|
||||
output[:num_decode_tokens] = self._v_up_proj(attn_out)
|
||||
return output_padded
|
||||
|
||||
@ -232,7 +232,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
self._workspace.get_buf(),
|
||||
self.scale, self._num_kv_splits)
|
||||
|
||||
return self._v_up_proj(o)
|
||||
return o
|
||||
|
||||
# TODO: Currently we leave it here only for backup in case something is
|
||||
# wrong with the new SM100 CUTLASS MLA kernel
|
||||
@ -265,21 +265,25 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
attn_metadata.decode.seq_lens,
|
||||
attn_metadata.decode.block_table, self.scale)
|
||||
|
||||
return self._v_up_proj(o)
|
||||
return o
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
q: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> torch.Tensor:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
if type(q) is tuple:
|
||||
q_nope, q_pe = q
|
||||
else:
|
||||
q_nope, q_pe = torch.split(
|
||||
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
if self._use_old_cutlass_mla:
|
||||
# TODO: Remove the old cutlass MLA kernel after more extensive
|
||||
# testing
|
||||
return self._old_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
|
||||
attn_metadata)
|
||||
attn_metadata), None
|
||||
|
||||
return self._sm100_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
|
||||
attn_metadata)
|
||||
attn_metadata), None
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar, Optional
|
||||
from typing import ClassVar, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -154,15 +154,20 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: FlashAttnMLAMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> torch.Tensor:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if type(q) is tuple:
|
||||
q_nope, q_pe = q
|
||||
else:
|
||||
q_nope, q_pe = torch.split(
|
||||
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError(
|
||||
"FP8 FlashAttention MLA not yet supported")
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar, Optional
|
||||
from typing import ClassVar, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -169,20 +169,20 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: FlashMLAMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> torch.Tensor:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)\
|
||||
.unsqueeze(1) # Add seqlen dim of 1 (decode)
|
||||
if type(q) is tuple:
|
||||
q = torch.cat(q, dim=-1)
|
||||
|
||||
o, _ = flash_mla_with_kvcache(
|
||||
q=q,
|
||||
assert isinstance(q, torch.Tensor)
|
||||
o, lse = flash_mla_with_kvcache(
|
||||
q=q.unsqueeze(1), # Add seqlen dim of 1 (decode)
|
||||
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
|
||||
block_table=attn_metadata.decode.block_table,
|
||||
cache_seqlens=attn_metadata.decode.seq_lens,
|
||||
@ -196,4 +196,4 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
descale_k=layer._k_scale.reshape(1),
|
||||
)
|
||||
|
||||
return self._v_up_proj(o)
|
||||
return o, lse
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar, Optional
|
||||
from typing import ClassVar, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -220,18 +220,19 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: AiterMLAMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> torch.Tensor:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
B = q_nope.shape[0]
|
||||
if type(q) is tuple:
|
||||
q = torch.cat(q, dim=-1)
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
assert isinstance(q, torch.Tensor)
|
||||
B = q.shape[0]
|
||||
o = torch.zeros(B,
|
||||
self.num_heads,
|
||||
self.kv_lora_rank,
|
||||
@ -249,4 +250,4 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
attn_metadata.decode.paged_kv_indices,
|
||||
attn_metadata.decode.paged_kv_last_page_len)
|
||||
|
||||
return self._v_up_proj(o)
|
||||
return o, None
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -123,21 +123,22 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> torch.Tensor:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError("FP8 Triton MLA not yet supported")
|
||||
|
||||
B = q_nope.shape[0]
|
||||
if type(q) is tuple:
|
||||
q = torch.cat(q, dim=-1)
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
assert isinstance(q, torch.Tensor)
|
||||
B = q.shape[0]
|
||||
o = torch.zeros(B,
|
||||
self.num_heads,
|
||||
self.kv_lora_rank,
|
||||
@ -171,4 +172,4 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
attn_metadata.decode.seq_lens, attn_logits,
|
||||
num_kv_splits, self.scale, PAGE_SIZE)
|
||||
|
||||
return self._v_up_proj(o)
|
||||
return o, None
|
||||
|
||||
@ -24,6 +24,7 @@ class KVCacheCoordinator(ABC):
|
||||
use_eagle: bool,
|
||||
enable_caching: bool,
|
||||
enable_kv_cache_events: bool,
|
||||
dcp_world_size: int,
|
||||
):
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.max_model_len = max_model_len
|
||||
@ -39,6 +40,7 @@ class KVCacheCoordinator(ABC):
|
||||
kv_cache_spec=kv_cache_group.kv_cache_spec,
|
||||
block_pool=self.block_pool,
|
||||
kv_cache_group_id=i,
|
||||
dcp_world_size=dcp_world_size,
|
||||
) for i, kv_cache_group in enumerate(
|
||||
self.kv_cache_config.kv_cache_groups))
|
||||
|
||||
@ -197,9 +199,14 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
|
||||
"""
|
||||
|
||||
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
|
||||
use_eagle: bool, enable_kv_cache_events: bool):
|
||||
super().__init__(kv_cache_config, max_model_len, use_eagle, False,
|
||||
enable_kv_cache_events)
|
||||
use_eagle: bool, enable_kv_cache_events: bool,
|
||||
dcp_world_size: int):
|
||||
super().__init__(kv_cache_config,
|
||||
max_model_len,
|
||||
use_eagle,
|
||||
False,
|
||||
enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size)
|
||||
self.num_single_type_manager = len(self.single_type_managers)
|
||||
|
||||
def get_num_common_prefix_blocks(self, request_id: str,
|
||||
@ -225,12 +232,19 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
|
||||
|
||||
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
|
||||
use_eagle: bool, enable_caching: bool,
|
||||
enable_kv_cache_events: bool):
|
||||
super().__init__(kv_cache_config, max_model_len, use_eagle,
|
||||
enable_caching, enable_kv_cache_events)
|
||||
enable_kv_cache_events: bool, dcp_world_size: int):
|
||||
super().__init__(kv_cache_config,
|
||||
max_model_len,
|
||||
use_eagle,
|
||||
enable_caching,
|
||||
enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size)
|
||||
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[
|
||||
0].kv_cache_spec
|
||||
self.block_size = self.kv_cache_spec.block_size
|
||||
self.dcp_world_size = dcp_world_size
|
||||
if dcp_world_size > 1:
|
||||
self.block_size *= dcp_world_size
|
||||
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
|
||||
"UnitaryKVCacheCoordinator assumes only one kv cache group")
|
||||
|
||||
@ -246,6 +260,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
|
||||
block_pool=self.block_pool,
|
||||
kv_cache_spec=self.kv_cache_spec,
|
||||
use_eagle=self.use_eagle,
|
||||
dcp_world_size=self.dcp_world_size,
|
||||
)
|
||||
return hit_blocks, len(hit_blocks[0]) * self.block_size
|
||||
|
||||
@ -261,9 +276,14 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
||||
|
||||
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
|
||||
use_eagle: bool, enable_caching: bool,
|
||||
enable_kv_cache_events: bool):
|
||||
super().__init__(kv_cache_config, max_model_len, use_eagle,
|
||||
enable_caching, enable_kv_cache_events)
|
||||
enable_kv_cache_events: bool, dcp_world_size: int):
|
||||
super().__init__(kv_cache_config,
|
||||
max_model_len,
|
||||
use_eagle,
|
||||
enable_caching,
|
||||
enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size)
|
||||
assert dcp_world_size == 1, "DCP not support hybrid attn now."
|
||||
self.verify_and_split_kv_cache_groups()
|
||||
|
||||
def verify_and_split_kv_cache_groups(self) -> None:
|
||||
@ -394,17 +414,27 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
||||
return hit_blocks, hit_length
|
||||
|
||||
|
||||
def get_kv_cache_coordinator(
|
||||
kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool,
|
||||
enable_caching: bool,
|
||||
enable_kv_cache_events: bool) -> KVCacheCoordinator:
|
||||
def get_kv_cache_coordinator(kv_cache_config: KVCacheConfig,
|
||||
max_model_len: int, use_eagle: bool,
|
||||
enable_caching: bool,
|
||||
enable_kv_cache_events: bool,
|
||||
dcp_world_size: int) -> KVCacheCoordinator:
|
||||
if not enable_caching:
|
||||
return KVCacheCoordinatorNoPrefixCache(kv_cache_config, max_model_len,
|
||||
return KVCacheCoordinatorNoPrefixCache(kv_cache_config,
|
||||
max_model_len,
|
||||
use_eagle,
|
||||
enable_kv_cache_events)
|
||||
enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size)
|
||||
if len(kv_cache_config.kv_cache_groups) == 1:
|
||||
return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len,
|
||||
use_eagle, enable_caching,
|
||||
enable_kv_cache_events)
|
||||
return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle,
|
||||
enable_caching, enable_kv_cache_events)
|
||||
return UnitaryKVCacheCoordinator(kv_cache_config,
|
||||
max_model_len,
|
||||
use_eagle,
|
||||
enable_caching,
|
||||
enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size)
|
||||
return HybridKVCacheCoordinator(kv_cache_config,
|
||||
max_model_len,
|
||||
use_eagle,
|
||||
enable_caching,
|
||||
enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size)
|
||||
|
||||
@ -91,6 +91,7 @@ class KVCacheManager:
|
||||
use_eagle: bool = False,
|
||||
log_stats: bool = False,
|
||||
enable_kv_cache_events: bool = False,
|
||||
dcp_world_size: int = 1,
|
||||
) -> None:
|
||||
self.max_model_len = max_model_len
|
||||
|
||||
@ -109,12 +110,20 @@ class KVCacheManager:
|
||||
self.block_size = kv_cache_config.kv_cache_groups[
|
||||
0].kv_cache_spec.block_size
|
||||
|
||||
if dcp_world_size > 1:
|
||||
assert len(kv_cache_config.kv_cache_groups) == 1
|
||||
# Note(hc): need revisit. When both DCP and any future
|
||||
# PCP are enabled, the block_size may need to be scaled
|
||||
# by a factor of dcp_size × pcp_size?
|
||||
self.block_size *= dcp_world_size
|
||||
|
||||
self.coordinator = get_kv_cache_coordinator(
|
||||
kv_cache_config=kv_cache_config,
|
||||
max_model_len=self.max_model_len,
|
||||
use_eagle=self.use_eagle,
|
||||
enable_caching=self.enable_caching,
|
||||
enable_kv_cache_events=enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size,
|
||||
)
|
||||
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
|
||||
self.block_pool = self.coordinator.block_pool
|
||||
|
||||
@ -846,6 +846,12 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
|
||||
)
|
||||
|
||||
num_tokens = num_blocks * vllm_config.cache_config.block_size
|
||||
if vllm_config.parallel_config.decode_context_parallel_size > 1:
|
||||
num_tokens *= vllm_config.parallel_config.decode_context_parallel_size
|
||||
logger.info(
|
||||
"Multiplying the GPU KV cache size by the dcp_world_size %d.",
|
||||
vllm_config.parallel_config.decode_context_parallel_size)
|
||||
|
||||
num_tokens_str = f"{num_tokens:,}"
|
||||
logger.info("GPU KV cache size: %s tokens", num_tokens_str)
|
||||
max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
|
||||
|
||||
@ -100,6 +100,15 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
self.block_size = self.cache_config.block_size
|
||||
|
||||
self.dcp_world_size = \
|
||||
vllm_config.parallel_config.decode_context_parallel_size
|
||||
# Note(hc): The scheduler’s block_size must be multiplied
|
||||
# by dcp_world_size, since block hashes are computed on the
|
||||
# original full token sequence at a granularity of
|
||||
# original_block_size × dcp_world_size.
|
||||
if self.dcp_world_size > 1:
|
||||
self.block_size *= self.dcp_world_size
|
||||
|
||||
# req_id -> Request
|
||||
self.requests: dict[str, Request] = {}
|
||||
# Scheduling policy
|
||||
@ -161,6 +170,7 @@ class Scheduler(SchedulerInterface):
|
||||
use_eagle=self.use_eagle,
|
||||
log_stats=self.log_stats,
|
||||
enable_kv_cache_events=self.enable_kv_cache_events,
|
||||
dcp_world_size=self.dcp_world_size,
|
||||
)
|
||||
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
|
||||
|
||||
|
||||
@ -25,6 +25,7 @@ class SingleTypeKVCacheManager(ABC):
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
block_pool: BlockPool,
|
||||
kv_cache_group_id: int,
|
||||
dcp_world_size: int = 1,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the SingleTypeKVCacheManager.
|
||||
@ -33,8 +34,10 @@ class SingleTypeKVCacheManager(ABC):
|
||||
block_pool: The block pool.
|
||||
kv_cache_group_id: The id of the kv cache group of this manager.
|
||||
"""
|
||||
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
self.dcp_world_size = dcp_world_size
|
||||
if self.dcp_world_size > 1:
|
||||
self.block_size *= dcp_world_size
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.block_pool = block_pool
|
||||
|
||||
@ -196,6 +199,7 @@ class SingleTypeKVCacheManager(ABC):
|
||||
block_pool: BlockPool,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
dcp_world_size: int = 1,
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
"""
|
||||
Get the longest cache hit prefix of the blocks that is not longer than
|
||||
@ -253,6 +257,7 @@ class FullAttentionManager(SingleTypeKVCacheManager):
|
||||
block_pool: BlockPool,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
dcp_world_size: int = 1,
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
assert isinstance(
|
||||
kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec)
|
||||
@ -260,7 +265,10 @@ class FullAttentionManager(SingleTypeKVCacheManager):
|
||||
"and chunked local attention groups"
|
||||
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
|
||||
[] for _ in range(len(kv_cache_group_ids)))
|
||||
max_num_blocks = max_length // kv_cache_spec.block_size
|
||||
block_size = kv_cache_spec.block_size
|
||||
if dcp_world_size > 1:
|
||||
block_size *= dcp_world_size
|
||||
max_num_blocks = max_length // block_size
|
||||
for block_hash in itertools.islice(block_hashes, max_num_blocks):
|
||||
# block_hashes is a chain of block hashes. If a block hash is not
|
||||
# in the cached_block_hash_to_id, the following block hashes are
|
||||
@ -310,9 +318,11 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
|
||||
block_pool: BlockPool,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
dcp_world_size: int = 1,
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
assert isinstance(kv_cache_spec, SlidingWindowSpec), (
|
||||
"SlidingWindowManager can only be used for sliding window groups")
|
||||
assert dcp_world_size == 1, "DCP not support sliding window attn now."
|
||||
|
||||
# The number of contiguous blocks needed for prefix cache hit.
|
||||
# -1 since the input token itself is also included in the window
|
||||
@ -408,6 +418,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
|
||||
block_pool: BlockPool,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
dcp_world_size: int = 1,
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
"""
|
||||
For chunked local attention, we need to find the longest cache hit
|
||||
@ -445,6 +456,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
|
||||
"chunked local attention groups")
|
||||
assert use_eagle is False, ("Hybrid KV cache is not supported for " +
|
||||
"eagle + chunked local attention.")
|
||||
assert dcp_world_size == 1, "DCP not support chunked local attn now."
|
||||
max_num_blocks = max_length // kv_cache_spec.block_size
|
||||
if max_length > 0:
|
||||
local_attention_start_idx = (max_length //
|
||||
@ -525,10 +537,12 @@ class MambaManager(SingleTypeKVCacheManager):
|
||||
block_pool: BlockPool,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
dcp_world_size: int = 1,
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
assert isinstance(
|
||||
kv_cache_spec,
|
||||
MambaSpec), ("MambaManager can only be used for mamba groups")
|
||||
assert dcp_world_size == 1, "DCP not support mamba now."
|
||||
# Prefix caching is not supported for mamba now. Always return empty
|
||||
# list.
|
||||
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
|
||||
@ -583,6 +597,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
|
||||
block_pool: BlockPool,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
dcp_world_size: int = 1,
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
assert isinstance(kv_cache_spec, CrossAttentionSpec), (
|
||||
"CrossAttentionManager can only be used for cross-attention groups"
|
||||
|
||||
@ -86,6 +86,12 @@ class FullAttentionSpec(AttentionSpec):
|
||||
|
||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
dcp_world_size = \
|
||||
vllm_config.parallel_config.decode_context_parallel_size
|
||||
# Note(hc): each dcp rank only need save
|
||||
# (max_model_len//dcp_world_size) tokens locally.
|
||||
if dcp_world_size > 1:
|
||||
max_model_len = cdiv(max_model_len, dcp_world_size)
|
||||
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
|
||||
|
||||
@classmethod
|
||||
@ -162,6 +168,8 @@ class SlidingWindowSpec(AttentionSpec):
|
||||
assert not self.use_mla, "MLA is not supported for sliding window"
|
||||
|
||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||
assert vllm_config.parallel_config.decode_context_parallel_size == 1, \
|
||||
"DCP not support sliding window."
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
max_num_batched_tokens = (
|
||||
vllm_config.scheduler_config.max_num_batched_tokens)
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.distributed import get_dcp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import cdiv
|
||||
|
||||
@ -50,6 +51,13 @@ class BlockTable:
|
||||
self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
|
||||
dtype=torch.int64,
|
||||
device=self.device)
|
||||
try:
|
||||
self.dcp_world_size = get_dcp_group().world_size
|
||||
self.dcp_rank = get_dcp_group().rank_in_group
|
||||
except AssertionError:
|
||||
# DCP might not be initialized in testing
|
||||
self.dcp_world_size = 1
|
||||
self.dcp_rank = 0
|
||||
|
||||
def append_row(
|
||||
self,
|
||||
@ -89,13 +97,36 @@ class BlockTable:
|
||||
# NOTE(woosuk): We can't simply use `token_indices // block_size`
|
||||
# here because M (max_model_len) is not necessarily divisible by
|
||||
# block_size.
|
||||
block_table_indices = (req_indices * self.max_num_blocks_per_req +
|
||||
positions // self.block_size)
|
||||
block_numbers = self.block_table_np.ravel()[block_table_indices]
|
||||
block_offsets = positions % self.block_size
|
||||
np.add(block_numbers * self.block_size,
|
||||
block_offsets,
|
||||
out=self.slot_mapping_np[:req_indices.shape[0]])
|
||||
if self.dcp_world_size > 1:
|
||||
# Note(hc): The DCP implement store kvcache with a interleave
|
||||
# style, the kvcache for the token whose token_idx is i is
|
||||
# always stored on the GPU whose dcp_rank equals i % cp_world_size:
|
||||
|
||||
# Use a "virtual block" which equals to world_size * block_size
|
||||
# for block_table_indices calculation.
|
||||
virtual_block_size = self.block_size * self.dcp_world_size
|
||||
block_table_indices = (req_indices * self.max_num_blocks_per_req +
|
||||
positions // virtual_block_size)
|
||||
block_numbers = self.block_table_np.ravel()[block_table_indices]
|
||||
# Use virtual_block_size for mask calculation, which marks local
|
||||
# tokens.
|
||||
virtual_block_offsets = positions % virtual_block_size
|
||||
mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank
|
||||
# Calcuate local block_offsets
|
||||
block_offsets = virtual_block_offsets // self.dcp_world_size
|
||||
# Calcuate slot_mapping
|
||||
slot_mapping = block_numbers * self.block_size + block_offsets
|
||||
# Write final slots, use -1 for not-local
|
||||
self.slot_mapping_np[:req_indices.shape[0]] = np.where(
|
||||
mask, slot_mapping, -1)
|
||||
else:
|
||||
block_table_indices = (req_indices * self.max_num_blocks_per_req +
|
||||
positions // self.block_size)
|
||||
block_numbers = self.block_table_np.ravel()[block_table_indices]
|
||||
block_offsets = positions % self.block_size
|
||||
np.add(block_numbers * self.block_size,
|
||||
block_offsets,
|
||||
out=self.slot_mapping_np[:req_indices.shape[0]])
|
||||
|
||||
def commit_block_table(self, num_reqs: int) -> None:
|
||||
self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
|
||||
@ -128,9 +159,19 @@ class MultiGroupBlockTable:
|
||||
def __init__(self, max_num_reqs: int, max_model_len: int,
|
||||
max_num_batched_tokens: int, pin_memory: bool,
|
||||
device: torch.device, block_sizes: list[int]) -> None:
|
||||
# Note(hc): each dcp rank only store
|
||||
# (max_model_len//dcp_world_size) tokens in kvcache,
|
||||
# so the block_size which used for calc max_num_blocks_per_req
|
||||
# must be multiplied by dcp_world_size.
|
||||
try:
|
||||
dcp_world_size = get_dcp_group().world_size
|
||||
except AssertionError:
|
||||
# DCP might not be initialized in testing
|
||||
dcp_world_size = 1
|
||||
|
||||
self.block_tables = [
|
||||
BlockTable(block_size, max_num_reqs, cdiv(max_model_len,
|
||||
block_size),
|
||||
BlockTable(block_size, max_num_reqs,
|
||||
cdiv(max_model_len, block_size * dcp_world_size),
|
||||
max_num_batched_tokens, pin_memory, device)
|
||||
for block_size in block_sizes
|
||||
]
|
||||
|
||||
@ -56,6 +56,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||
GiB_bytes, LazyLoader, cdiv, check_use_alibi,
|
||||
get_dtype_size, is_pin_memory_available, round_up,
|
||||
supports_dynamo)
|
||||
from vllm.v1.attention.backends.mla.flashmla import FlashMLABackend
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
|
||||
create_fast_prefill_custom_backend,
|
||||
@ -187,6 +188,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
model_config.is_multimodal_raw_input_only_model)
|
||||
|
||||
self.max_model_len = model_config.max_model_len
|
||||
self.dcp_world_size = self.parallel_config.decode_context_parallel_size
|
||||
self.max_num_tokens = scheduler_config.max_num_batched_tokens
|
||||
self.max_num_reqs = scheduler_config.max_num_seqs
|
||||
|
||||
@ -428,6 +430,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
return
|
||||
|
||||
if self.reorder_batch_threshold is not None:
|
||||
if self.dcp_world_size > 1:
|
||||
assert self.reorder_batch_threshold == 1, \
|
||||
"DCP not support reorder_batch_threshold > 1 now."
|
||||
reorder_batch_to_split_decodes_and_prefills(
|
||||
self.input_batch,
|
||||
scheduler_output,
|
||||
@ -3305,6 +3310,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
get_kv_transfer_group().set_host_xfer_buffer_ops(
|
||||
copy_kv_blocks)
|
||||
|
||||
if self.dcp_world_size > 1:
|
||||
assert self.attn_groups[0][0].backend is FlashMLABackend, (
|
||||
"DCP only support flashmla now."
|
||||
"For a mla backend want to enable DCP, it is mandatory that the"
|
||||
"corresponding decode attn kernel return the softmax lse.")
|
||||
|
||||
def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
|
||||
"""
|
||||
Add encoder-only layers to the KV cache config.
|
||||
|
||||
@ -616,7 +616,9 @@ def init_worker_distributed_environment(
|
||||
init_distributed_environment(parallel_config.world_size, rank,
|
||||
distributed_init_method, local_rank, backend)
|
||||
|
||||
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size)
|
||||
ensure_model_parallel_initialized(
|
||||
parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size,
|
||||
parallel_config.decode_context_parallel_size)
|
||||
|
||||
ensure_kv_transfer_initialized(vllm_config)
|
||||
|
||||
@ -539,8 +539,10 @@ def init_worker_distributed_environment(
|
||||
init_distributed_environment(parallel_config.world_size, rank,
|
||||
distributed_init_method, local_rank,
|
||||
current_platform.dist_backend)
|
||||
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size)
|
||||
ensure_model_parallel_initialized(
|
||||
parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size,
|
||||
parallel_config.decode_context_parallel_size)
|
||||
|
||||
ensure_kv_transfer_initialized(vllm_config)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user