[Feature] Support Decode Context Parallel (DCP) for MLA (#23734)
Signed-off-by: hongchao <hongchao@msh.team>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: hongchao <hongchao@msh.team>
Co-authored-by: youkaichao <youkaichao@gmail.com>

parent 3c529fc994
commit ac201a0eaf
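
The flag introduced by this commit is exposed through `EngineArgs` as `--decode-context-parallel-size` (`-dcp`). A hedged launch sketch in Python that only assembles the command line (the model id comes from the test added below; the sizes are illustrative and must satisfy `tp_size % dcp_size == 0`):

```python
# Hypothetical launch settings; only the flag names and the divisibility
# constraint come from this commit, everything else is illustrative.
model_id = "deepseek-ai/DeepSeek-V2-Lite-Chat"
serve_args = [
    "--tensor-parallel-size", "4",
    "--decode-context-parallel-size", "2",  # new -dcp flag added here
    "--enable-chunked-prefill",
]
print(f"vllm serve {model_id} " + " ".join(serve_args))
```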
@@ -837,7 +837,7 @@ steps:
   - pytest -v -s models/test_oot_registration.py # it needs a clean process
   - pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins

-- label: Pipeline Parallelism Test # 45min
+- label: Pipeline + Context Parallelism Test # 45min
   timeout_in_minutes: 60
   mirror_hardwares: [amdexperimental]
   working_dir: "/vllm-workspace/tests"
@@ -851,6 +851,7 @@ steps:
   commands:
   - pytest -v -s distributed/test_pp_cudagraph.py
   - pytest -v -s distributed/test_pipeline_parallel.py
+  # - pytest -v -s distributed/test_context_parallel.py # TODO: enable it on Hopper runners or add triton MLA support

- label: LoRA TP Test (Distributed) # 17 min
   timeout_in_minutes: 30
@@ -36,13 +36,6 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
                           const std::string& kv_cache_dtype,
                           torch::Tensor& scale);

-void cp_fused_concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
-                                   torch::Tensor& cp_local_token_select_indices,
-                                   torch::Tensor& kv_cache,
-                                   torch::Tensor& slot_mapping,
-                                   const std::string& kv_cache_dtype,
-                                   torch::Tensor& scale);
-
 // Just for unittest
 void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                  const double scale, const std::string& kv_cache_dtype);
@@ -396,51 +396,6 @@ __global__ void concat_and_cache_mla_kernel(
   copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
 }

-template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
-__global__ void cp_fused_concat_and_cache_mla_kernel(
-    const scalar_t* __restrict__ kv_c,  // [num_full_tokens, kv_lora_rank]
-    const scalar_t* __restrict__ k_pe,  // [num_full_tokens, pe_dim]
-    const int64_t* __restrict__ cp_local_token_select_indices,  // [num_tokens]
-    cache_t* __restrict__ kv_cache,  // [num_blocks, block_size, (kv_lora_rank
-                                     // + pe_dim)]
-    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
-    const int block_stride,  //
-    const int entry_stride,  //
-    const int kv_c_stride,   //
-    const int k_pe_stride,   //
-    const int kv_lora_rank,  //
-    const int pe_dim,        //
-    const int block_size,    //
-    const float* scale       //
-) {
-  const int64_t token_idx = cp_local_token_select_indices[blockIdx.x];
-  const int64_t slot_idx = slot_mapping[blockIdx.x];
-  // NOTE: slot_idx can be -1 if the token is padded
-  if (slot_idx < 0) {
-    return;
-  }
-  const int64_t block_idx = slot_idx / block_size;
-  const int64_t block_offset = slot_idx % block_size;
-
-  auto copy = [&](const scalar_t* __restrict__ src, cache_t* __restrict__ dst,
-                  int src_stride, int dst_stride, int size, int offset) {
-    for (int i = threadIdx.x; i < size; i += blockDim.x) {
-      const int64_t src_idx = token_idx * src_stride + i;
-      const int64_t dst_idx =
-          block_idx * block_stride + block_offset * entry_stride + i + offset;
-      if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
-        dst[dst_idx] = src[src_idx];
-      } else {
-        dst[dst_idx] =
-            fp8::scaled_convert<cache_t, scalar_t, kv_dt>(src[src_idx], *scale);
-      }
-    }
-  };
-
-  copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
-  copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
-}
-
 }  // namespace vllm

 // KV_T is the data type of key and value tensors.
@@ -554,20 +509,6 @@ void reshape_and_cache_flash(
       kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
       reinterpret_cast<const float*>(scale.data_ptr()));

-// KV_T is the data type of key and value tensors.
-// CACHE_T is the stored data type of kv-cache.
-// KV_DTYPE is the real data type of kv-cache.
-#define CALL_CP_FUSED_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
-  vllm::cp_fused_concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
-      <<<grid, block, 0, stream>>>( \
-          reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
-          reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
-          cp_local_token_select_indices.data_ptr<int64_t>(), \
-          reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
-          slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride, \
-          kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
-          reinterpret_cast<const float*>(scale.data_ptr()));
-
 void concat_and_cache_mla(
     torch::Tensor& kv_c,  // [num_tokens, kv_lora_rank]
     torch::Tensor& k_pe,  // [num_tokens, pe_dim]
@@ -606,50 +547,6 @@ void concat_and_cache_mla(
                              CALL_CONCAT_AND_CACHE_MLA);
 }

-// Note(hc): cp_fused_concat_and_cache_mla fuses the following three kernel
-// calls into one:
-// k_c_normed.index_select(0, cp_local_token_select_indices) + \
-// k_pe.squeeze(1).index_select(0, cp_local_token_select_indices) + \
-// concat_and_cache_mla.
-void cp_fused_concat_and_cache_mla(
-    torch::Tensor& kv_c,  // [num_total_tokens, kv_lora_rank]
-    torch::Tensor& k_pe,  // [num_total_tokens, pe_dim]
-    torch::Tensor& cp_local_token_select_indices,  // [num_tokens]
-    torch::Tensor& kv_cache,  // [num_blocks, block_size, (kv_lora_rank +
-                              // pe_dim)]
-    torch::Tensor& slot_mapping,  // [num_tokens] or [num_actual_tokens]
-    const std::string& kv_cache_dtype, torch::Tensor& scale) {
-  // NOTE(woosuk): In vLLM V1, key.size(0) can be different from
-  // slot_mapping.size(0) because of padding for CUDA graphs.
-  // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
-  // both include padding.
-  // In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0)
-  // since key includes padding for CUDA graphs, while slot_mapping does not.
-  // In this case, slot_mapping.size(0) represents the actual number of tokens
-  // before padding.
-  // For compatibility with both cases, we use slot_mapping.size(0) as the
-  // number of tokens.
-  int num_tokens = slot_mapping.size(0);
-  int kv_lora_rank = kv_c.size(1);
-  int pe_dim = k_pe.size(1);
-  int block_size = kv_cache.size(1);
-
-  TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
-
-  int kv_c_stride = kv_c.stride(0);
-  int k_pe_stride = k_pe.stride(0);
-  int block_stride = kv_cache.stride(0);
-  int entry_stride = kv_cache.stride(1);
-
-  dim3 grid(num_tokens);
-  dim3 block(std::min(kv_lora_rank, 512));
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
-  DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
-                             CALL_CP_FUSED_CONCAT_AND_CACHE_MLA);
-}
-
 namespace vllm {

 template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
@@ -693,16 +693,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
      " Tensor scale) -> ()");
   cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla);

-  cache_ops.def(
-      "cp_fused_concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
-      " Tensor cp_local_token_select_indices,"
-      " Tensor! kv_cache,"
-      " Tensor slot_mapping,"
-      " str kv_cache_dtype,"
-      " Tensor scale) -> ()");
-  cache_ops.impl("cp_fused_concat_and_cache_mla", torch::kCUDA,
-                 &cp_fused_concat_and_cache_mla);
-
   // Convert the key and value cache to fp8 data type.
   cache_ops.def(
       "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
tests/distributed/test_context_parallel.py (new file, 263 lines)
@@ -0,0 +1,263 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+WARNING: This test runs in both single-node (4 GPUs) and multi-node
+(2 node with 2 GPUs each) modes. If the test only uses 2 GPUs, it is
+important to set the distributed backend to "mp" to avoid Ray scheduling
+all workers in a node other than the head node, which can cause the test
+to fail.
+"""
+import json
+import os
+from dataclasses import dataclass
+from typing import Literal, NamedTuple, Optional
+
+import pytest
+
+from vllm.config import RunnerOption
+from vllm.logger import init_logger
+
+from ..models.registry import HF_EXAMPLE_MODELS
+from ..utils import compare_two_settings, create_new_process_for_each_test
+
+logger = init_logger("test_context_parallel")
+
+VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
+
+
+class ParallelSetup(NamedTuple):
+    tp_size: int
+    pp_size: int
+    dcp_size: int
+    eager_mode: bool
+    chunked_prefill: bool
+
+
+class CPTestOptions(NamedTuple):
+    multi_node_only: bool
+    load_format: Optional[str] = None
+
+
+@dataclass
+class CPTestSettings:
+    parallel_setups: list[ParallelSetup]
+    # NOTE: the length of distributed_backends and
+    # vllm_major_versions should be the same, and they
+    # are first zipped together to iterate over all
+    # test settings.
+    distributed_backends: list[str]
+    # vllm major version: "0" for V0, "1" for V1
+    vllm_major_versions: list[str]
+    runner: RunnerOption
+    test_options: CPTestOptions
+
+    def __post_init__(self):
+        if len(self.distributed_backends) != len(self.vllm_major_versions):
+            raise ValueError(
+                f"Length mismatch: distributed_backends "
+                f"({len(self.distributed_backends)}) != "
+                f"vllm_major_versions ({len(self.vllm_major_versions)})")
+
+    @staticmethod
+    def detailed(
+        *,
+        tp_base: int = 4,
+        pp_base: int = 1,
+        dcp_base: int = 1,
+        multi_node_only: bool = False,
+        runner: RunnerOption = "auto",
+        load_format: Optional[str] = None,
+    ):
+        parallel_setups = []
+        for eager_mode_val in [False]:
+            for pp_multiplier in [1]:
+                for dcp_multiplier in [2, 4]:
+                    for chunked_prefill_val in [True]:
+                        parallel_setups.append(
+                            ParallelSetup(tp_size=tp_base,
+                                          pp_size=pp_multiplier * pp_base,
+                                          dcp_size=dcp_multiplier * dcp_base,
+                                          eager_mode=eager_mode_val,
+                                          chunked_prefill=chunked_prefill_val))
+        return CPTestSettings(
+            parallel_setups=parallel_setups,
+            distributed_backends=["mp"],
+            vllm_major_versions=["1"],
+            runner=runner,
+            test_options=CPTestOptions(multi_node_only=multi_node_only,
+                                       load_format=load_format),
+        )
+
+    def iter_params(self, model_id: str):
+        opts = self.test_options
+
+        for parallel_setup in self.parallel_setups:
+            for backend, vllm_major_version in zip(self.distributed_backends,
+                                                   self.vllm_major_versions):
+                yield (model_id, parallel_setup, backend, vllm_major_version,
+                       self.runner, opts)
+
+
+def _compare_cp_with_tp(
+    model_id: str,
+    parallel_setup: ParallelSetup,
+    distributed_backend: str,
+    vllm_major_version: str,
+    runner: RunnerOption,
+    test_options: CPTestOptions,
+    num_gpus_available: int,
+    *,
+    method: Literal["generate"],
+    is_multimodal: bool,
+):
+    (
+        tp_size,
+        pp_size,
+        dcp_size,
+        eager_mode,
+        chunked_prefill,
+    ) = parallel_setup
+
+    multi_node_only, load_format = test_options
+
+    model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
+    model_info.check_transformers_version(on_fail="skip")
+
+    trust_remote_code = model_info.trust_remote_code
+    tokenizer_mode = model_info.tokenizer_mode
+    hf_overrides = model_info.hf_overrides
+
+    if load_format == "dummy":
+        # Avoid OOM
+        text_overrides = {
+            "num_hidden_layers": 4,
+            "hidden_size": 512,
+            "intermediate_size": 800,
+            "num_attention_heads": 4,
+            "num_key_value_heads": 1,
+        }
+
+        if is_multimodal:
+            hf_overrides.update({"text_config": text_overrides})
+        else:
+            hf_overrides.update(text_overrides)
+    else:
+        model_info.check_available_online(on_fail="skip")
+
+    if num_gpus_available < tp_size * pp_size:
+        pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
+    if VLLM_MULTI_NODE and distributed_backend == "mp":
+        pytest.skip("Skipping multi-node pipeline parallel test for "
+                    "multiprocessing distributed backend")
+    if multi_node_only and not VLLM_MULTI_NODE:
+        pytest.skip("Not in multi-node setting")
+
+    common_args = [
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "2048",
+        "--max-num-seqs",
+        "8",
+    ]
+    if chunked_prefill:
+        common_args.append("--enable-chunked-prefill")
+    if eager_mode:
+        common_args.append("--enforce-eager")
+    if runner != "auto":
+        common_args.extend(["--runner", runner])
+    if trust_remote_code:
+        common_args.append("--trust-remote-code")
+    if tokenizer_mode:
+        common_args.extend(["--tokenizer-mode", tokenizer_mode])
+    if load_format:
+        common_args.extend(["--load-format", load_format])
+    if hf_overrides:
+        common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
+
+    cp_env = tp_env = {
+        "VLLM_USE_V1":
+        vllm_major_version,  # Note(hc): DCP only support V1 engine only
+    }
+
+    cp_args = [
+        *common_args,
+        "--tensor-parallel-size",
+        str(tp_size),
+        "--pipeline-parallel-size",
+        str(pp_size),
+        "--decode-context-parallel-size",
+        str(dcp_size),
+        "--distributed-executor-backend",
+        distributed_backend,
+    ]
+
+    tp_args = [
+        *common_args,
+        "--tensor-parallel-size",
+        str(tp_size),
+        "--pipeline-parallel-size",
+        str(pp_size),
+        "--distributed-executor-backend",
+        distributed_backend,
+    ]
+
+    try:
+        compare_two_settings(model_id,
+                             cp_args,
+                             tp_args,
+                             cp_env,
+                             tp_env,
+                             method=method,
+                             max_wait_seconds=720)
+    except Exception:
+        testing_ray_compiled_graph = cp_env is not None
+        if testing_ray_compiled_graph and vllm_major_version == "0":
+            # Ray Compiled Graph tests are flaky for V0,
+            # so we don't want to fail the test
+            logger.exception("Ray Compiled Graph tests failed")
+        else:
+            raise
+
+
+CP_TEXT_GENERATION_MODELS = {
+    # [MLA attention only]
+    "deepseek-ai/DeepSeek-V2-Lite-Chat": CPTestSettings.detailed(),
+}
+
+CP_TEST_MODELS = [
+    # TODO support other models
+    # [LANGUAGE GENERATION]
+    "deepseek-ai/DeepSeek-V2-Lite-Chat",
+]
+
+
+@pytest.mark.parametrize(
+    ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
+     "runner", "test_options"),
+    [
+        params for model_id, settings in CP_TEXT_GENERATION_MODELS.items()
+        for params in settings.iter_params(model_id)
+        if model_id in CP_TEST_MODELS
+    ],
+)
+@create_new_process_for_each_test()
+def test_cp_generation(
+    model_id: str,
+    parallel_setup: ParallelSetup,
+    distributed_backend: str,
+    vllm_major_version: str,
+    runner: RunnerOption,
+    test_options: CPTestOptions,
+    num_gpus_available,
+):
+    _compare_cp_with_tp(model_id,
+                        parallel_setup,
+                        distributed_backend,
+                        vllm_major_version,
+                        runner,
+                        test_options,
+                        num_gpus_available,
+                        method="generate",
+                        is_multimodal=False)
@@ -1625,20 +1625,6 @@ def concat_and_cache_mla(
                                                      scale)


-def cp_fused_concat_and_cache_mla(
-    kv_c: torch.Tensor,
-    k_pe: torch.Tensor,
-    cp_local_token_select_indices: torch.Tensor,
-    kv_cache: torch.Tensor,
-    slot_mapping: torch.Tensor,
-    kv_cache_dtype: str,
-    scale: torch.Tensor,
-) -> None:
-    torch.ops._C_cache_ops.cp_fused_concat_and_cache_mla(
-        kv_c, k_pe, cp_local_token_select_indices, kv_cache, slot_mapping,
-        kv_cache_dtype, scale)
-
-
 def copy_blocks(key_caches: list[torch.Tensor],
                 value_caches: list[torch.Tensor],
                 block_mapping: torch.Tensor) -> None:
vllm/attention/ops/common.py (new file, 139 lines)
@@ -0,0 +1,139 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+
+from vllm.distributed.parallel_state import GroupCoordinator
+from vllm.triton_utils import tl, triton
+
+
+@triton.jit
+def _correct_attn_cp_out_kernel(outputs_ptr, new_output_ptr, lses_ptr,
+                                vlse_ptr, outputs_stride_B, outputs_stride_H,
+                                outputs_stride_D, lses_stride_N, lses_stride_B,
+                                lses_stride_H, lse_idx, HEAD_DIM: tl.constexpr,
+                                N_ROUNDED: tl.constexpr):
+    """
+    Apply the all-gathered lses to correct each local rank's attention
+    output. we still need perform a cross-rank reduction to obtain the
+    final attention output.
+
+    Args:
+        output: [ B, H, D ]
+        lses  : [ N, B, H ]
+                cp, batch, q_heads, v_head_dim
+    Return:
+        output: [ B, H, D ]
+        lse   : [ B, H ]
+    """
+    batch_idx = tl.program_id(axis=0).to(tl.int64)
+    head_idx = tl.program_id(axis=1).to(tl.int64)
+    d_offsets = tl.arange(0, HEAD_DIM)
+    num_n_offsets = tl.arange(0, N_ROUNDED)
+
+    # shape = [N]
+    lse_offsets = num_n_offsets * lses_stride_N + batch_idx * \
+        lses_stride_B + head_idx * lses_stride_H
+
+    # calc final lse
+    lse = tl.load(lses_ptr + lse_offsets)
+    lse = tl.where((lse != lse) | (lse == float('inf')), -float('inf'), lse)
+    lse_max = tl.max(lse, axis=0)
+    lse -= lse_max
+    lse_exp = tl.exp(lse)
+    lse_acc = tl.sum(lse_exp, axis=0)
+    lse = tl.log(lse_acc)
+    lse += lse_max
+
+    lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H
+    tl.store(vlse_ptr + lse_offsets, lse)
+
+    # shape = [D]
+    output_offsets = batch_idx * outputs_stride_B + \
+        head_idx * outputs_stride_H + \
+        d_offsets * outputs_stride_D
+
+    # correct output
+    lse_offset = lse_idx * lses_stride_N + batch_idx * \
+        lses_stride_B + head_idx * lses_stride_H
+    lse_tmp = tl.load(lses_ptr + lse_offset)
+    lse_finally = lse_tmp - lse
+    lse_finally = tl.where(
+        (lse_finally != lse_finally) | (lse_finally == float('inf')),
+        -float('inf'), lse_finally)
+    factor = tl.exp(lse_finally)
+    output = tl.load(outputs_ptr + output_offsets)
+    output = output * factor
+
+    tl.store(new_output_ptr + output_offsets, output)
+
+
+class CPTritonContext:
+    """ The CPTritonContext is used to avoid recompilation of the Triton JIT.
+    """
+
+    def __init__(self):
+        self.inner_kernel = None
+
+    def call_kernel(self, kernel, grid, *regular_args, **const_args):
+        if self.inner_kernel is None:
+            self.inner_kernel = kernel[grid](*regular_args, **const_args)
+        else:
+            self.inner_kernel[grid](*regular_args)
+
+
+def correct_attn_out(out: torch.Tensor, lses: torch.Tensor, cp_rank: int,
+                     ctx: CPTritonContext):
+    """
+    Apply the all-gathered lses to correct each local rank's attention
+    output. we still need perform a cross-rank reduction to obtain the
+    final attention output.
+
+    Args:
+        output: [ B, H, D ]
+        lses  : [ N, B, H ]
+    Return:
+        output: [ B, H, D ]
+        lse   : [ B, H ]
+    """
+    if ctx is None:
+        ctx = CPTritonContext()
+
+    lse = torch.empty_like(lses[0])
+
+    grid = (out.shape[0], out.shape[1], 1)
+    regular_args = (out, out, lses, lse, *out.stride(), *lses.stride(),
+                    cp_rank)
+    const_args = {
+        "HEAD_DIM": out.shape[-1],
+        "N_ROUNDED": lses.shape[0],
+    }
+
+    ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args,
+                    **const_args)
+    return out, lse
+
+
+def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor,
+                     cp_attn_lse: torch.Tensor,
+                     cp_group: GroupCoordinator,
+                     ctx: CPTritonContext = None):
+    """
+    cp_attn_out: [ B, H, D ]
+    cp_attn_lse: [ B, H ]
+    """
+    if cp_group.world_size == 1:
+        return cp_attn_out
+
+    if ctx is None:
+        ctx = CPTritonContext()
+
+    lses = torch.empty((cp_group.world_size, ) + cp_attn_lse.shape,
+                       dtype=cp_attn_lse.dtype,
+                       device=cp_attn_lse.device)
+
+    cp_attn_lse = cp_attn_lse.contiguous()
+    lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
+    out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
+    assert out.is_contiguous()
+    out = cp_group.reduce_scatter(out, dim=1)
+    return out
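
The Triton kernel in this new file implements the standard log-sum-exp merge of per-rank attention partials: each rank's output is rescaled by `exp(lse_rank - lse_total)` before the cross-rank reduction. A minimal single-process check of that algebra in plain PyTorch (no Triton and no process group; the two-way split and the shapes are purely illustrative):

```python
import torch

def attn(q, k, v):
    # single-query attention over one KV shard; returns (out, lse)
    s = (k @ q) / k.shape[-1] ** 0.5      # [T]
    lse = torch.logsumexp(s, dim=0)       # scalar log-normalizer of this shard
    out = torch.softmax(s, dim=0) @ v     # [D]
    return out, lse

torch.manual_seed(0)
D = 16
q = torch.randn(D)
k = torch.randn(8, D)
v = torch.randn(8, D)

# reference: attention over the full sequence
ref, _ = attn(q, k, v)

# "DCP": each of the two pseudo-ranks sees an interleaved shard of the KV
outs, lses = zip(*(attn(q, k[r::2], v[r::2]) for r in range(2)))
lses = torch.stack(lses)
total_lse = torch.logsumexp(lses, dim=0)
merged = sum(o * torch.exp(l - total_lse) for o, l in zip(outs, lses))

assert torch.allclose(merged, ref, atol=1e-5)
```

This identity is what lets `cp_lse_ag_out_rs` all-gather only the LSEs, rescale each rank's partial output, and then reduce across ranks.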
@@ -105,7 +105,9 @@ def flash_mla_with_kvcache(
         descale_q,
         descale_k,
     )
-    return out, softmax_lse
+
+    # Note(hc): need revisit when we support DCP with decode query_len > 1.
+    return out.squeeze(1), softmax_lse.squeeze(-1)


 #
@@ -170,6 +170,11 @@ class ParallelConfig:
     Set to be private as it's not intended to be configured by users.
     """

+    decode_context_parallel_size: int = 1
+    """Number of decode context parallel groups, because the world size does
+    not change by dcp, it simply reuse the GPUs of TP group, and tp_size
+    needs to be divisible by dcp_size."""
+
     @property
     def world_size_across_dp(self) -> int:
         """world_size_across_dp is TPxPPxDP, it is the size of the world
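
A tiny sketch of the constraint stated in the new docstring. The helper name is hypothetical; only the divisibility rule and the `tp_size // dcp_size` split come from this commit:

```python
def dcp_groups_per_tp_group(tp_size: int, dcp_size: int) -> int:
    # DCP reuses the ranks of the TP group, splitting each TP group
    # into tp_size // dcp_size DCP groups, so dcp_size must divide tp_size.
    assert tp_size % dcp_size == 0, "tp_size must be divisible by dcp_size"
    return tp_size // dcp_size

assert dcp_groups_per_tp_group(8, 2) == 4  # tp=8, dcp=2 -> 4 DCP groups of 2 ranks
assert dcp_groups_per_tp_group(4, 4) == 1  # dcp == tp is allowed
```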
@@ -904,6 +904,18 @@ def get_tensor_model_parallel_group():
     return get_tp_group()


+_DCP: Optional[GroupCoordinator] = None
+
+
+def get_dcp_group() -> GroupCoordinator:
+    assert _DCP is not None, (
+        "decode context model parallel group is not initialized")
+    return _DCP
+
+
+# kept for backward compatibility
+get_context_model_parallel_group = get_dcp_group
+
 _PP: Optional[GroupCoordinator] = None

 _DP: Optional[GroupCoordinator] = None
@@ -1034,6 +1046,7 @@ def init_distributed_environment(
 def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
     pipeline_model_parallel_size: int = 1,
+    decode_context_model_parallel_size: Optional[int] = 1,
     backend: Optional[str] = None,
 ) -> None:
     """
@@ -1098,6 +1111,23 @@ def initialize_model_parallel(
         use_message_queue_broadcaster=True,
         group_name="tp")

+    # Build the DCP model-parallel groups.
+    global _DCP
+    assert _DCP is None, (
+        "decode context model parallel group is already initialized")
+    # Note(hc): In the current implementation of decode context parallel,
+    # dcp_size must not exceed tp_size, because the world size does not
+    # change by DCP, it simply reuse the GPUs of TP group, and split one
+    # TP group into tp_size//dcp_size DCP groups.
+    group_ranks = all_ranks.reshape(
+        -1, decode_context_model_parallel_size).unbind(0)
+    group_ranks = [x.tolist() for x in group_ranks]
+    _DCP = init_model_parallel_group(group_ranks,
+                                     get_world_group().local_rank,
+                                     backend,
+                                     use_message_queue_broadcaster=True,
+                                     group_name="dcp")
+
     # Build the pipeline model-parallel groups.
     global _PP
     assert _PP is None, (
@@ -1141,6 +1171,7 @@ def initialize_model_parallel(
 def ensure_model_parallel_initialized(
     tensor_model_parallel_size: int,
     pipeline_model_parallel_size: int,
+    decode_context_model_parallel_size: Optional[int] = 1,
     backend: Optional[str] = None,
 ) -> None:
     """Helper to initialize model parallel groups if they are not initialized,
@@ -1151,7 +1182,8 @@ def ensure_model_parallel_initialized(
         get_world_group().device_group)
     if not model_parallel_is_initialized():
         initialize_model_parallel(tensor_model_parallel_size,
-                                  pipeline_model_parallel_size, backend)
+                                  pipeline_model_parallel_size,
+                                  decode_context_model_parallel_size, backend)
         return

     assert (
@@ -1226,6 +1258,16 @@ def get_tensor_model_parallel_rank():
     return get_tp_group().rank_in_group


+def get_decode_context_model_parallel_world_size():
+    """Return world size for the decode context model parallel group."""
+    return get_dcp_group().world_size
+
+
+def get_decode_context_model_parallel_rank():
+    """Return my rank for the decode context model parallel group."""
+    return get_dcp_group().rank_in_group
+
+
 def get_node_count() -> int:
     """Return the total number of nodes in the distributed environment. """
     assert _NODE_COUNT is not None, (
@@ -1246,6 +1288,11 @@ def destroy_model_parallel():
         _PP.destroy()
     _PP = None

+    global _DCP
+    if _DCP:
+        _DCP.destroy()
+    _DCP = None
+
     global _DP
     if _DP:
         _DP.destroy()
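
How the `group_ranks` reshape above carves DCP groups out of the existing ranks can be checked on the CPU. The flat `all_ranks` layout below is an assumption made for illustration, since the code that actually builds `all_ranks` is not part of this hunk:

```python
import torch

world_size, dcp_size = 8, 2

# stand-in for `all_ranks`; the real tensor comes from the global rank layout
# in initialize_model_parallel, so treat this flat ordering as an assumption
all_ranks = torch.arange(world_size)

group_ranks = [x.tolist() for x in all_ranks.reshape(-1, dcp_size).unbind(0)]
print(group_ranks)  # [[0, 1], [2, 3], [4, 5], [6, 7]] -> 4 DCP groups of 2 ranks
```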
@@ -306,6 +306,8 @@ class EngineArgs:
     # number of P/D disaggregation (or other disaggregation) workers
     pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
     tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
+    decode_context_parallel_size: int = \
+        ParallelConfig.decode_context_parallel_size
     data_parallel_size: int = ParallelConfig.data_parallel_size
     data_parallel_rank: Optional[int] = None
     data_parallel_start_rank: Optional[int] = None
@@ -636,6 +638,9 @@ class EngineArgs:
             **parallel_kwargs["pipeline_parallel_size"])
         parallel_group.add_argument("--tensor-parallel-size", "-tp",
                                     **parallel_kwargs["tensor_parallel_size"])
+        parallel_group.add_argument(
+            "--decode-context-parallel-size", "-dcp",
+            **parallel_kwargs["decode_context_parallel_size"])
         parallel_group.add_argument("--data-parallel-size", "-dp",
                                     **parallel_kwargs["data_parallel_size"])
         parallel_group.add_argument(
@@ -1156,6 +1161,17 @@ class EngineArgs:
         # global layers in interleaved sliding window models.
         sliding_window = model_config.get_sliding_window()

+        # Note(hc): In the current implementation of decode context
+        # parallel(DCP), tp_size needs to be divisible by dcp_size,
+        # because the world size does not change by dcp, it simply
+        # reuse the GPUs of TP group, and split one TP group into
+        # tp_size//dcp_size DCP groups.
+        assert self.tensor_parallel_size % self.decode_context_parallel_size \
+            == 0, (
+            f"tp_size={self.tensor_parallel_size} must be divisible by"
+            f"dcp_size={self.decode_context_parallel_size}."
+        )
+
         cache_config = CacheConfig(
             block_size=self.block_size,
             gpu_memory_utilization=self.gpu_memory_utilization,
@@ -1306,6 +1322,7 @@ class EngineArgs:
             distributed_executor_backend=self.distributed_executor_backend,
             worker_cls=self.worker_cls,
             worker_extension_cls=self.worker_extension_cls,
+            decode_context_parallel_size=self.decode_context_parallel_size,
         )

         speculative_config = self.create_speculative_config(
@@ -201,10 +201,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import get_mla_dims
+from vllm.attention.ops.common import cp_lse_ag_out_rs
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.config import VllmConfig
-from vllm.distributed.parallel_state import is_global_first_rank
+from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase,
@@ -323,6 +324,13 @@ class MLACommonPrefillMetadata:
         seq_lens: torch.Tensor
         workspace: torch.Tensor

+        # for mla DCP
+        cp_chunk_seq_lens: Optional[list[list[int]]] = None
+        origin_context_lens: Optional[list[int]] = None
+        cp_cu_seq_lens: Optional[torch.Tensor] = None
+        chunk_size: Optional[int] = None
+        cu_seq_lens_lst: Optional[list[list[int]]] = None
+
     block_table: torch.Tensor
     query_start_loc: torch.Tensor
     max_query_len: int
@@ -444,6 +452,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                                                      parallel_config)
         self.mla_dims = get_mla_dims(self.model_config)
         self.aot_schedule = current_platform.is_cuda()
+        try:
+            self.dcp_world_size = get_dcp_group().world_size
+            self.dcp_rank = get_dcp_group().rank_in_group
+        except AssertionError:
+            # DCP might not be initialized in testing
+            self.dcp_world_size = 1
+            self.dcp_rank = 0

         # Dont try to access the runner on AMD
         if self.aot_schedule:
@@ -465,6 +480,21 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 128 * 1024)
             assert self.chunked_prefill_workspace_size >= \
                 scheduler_config.max_num_seqs * cache_config.block_size
+            if self.dcp_world_size > 1:
+                # Note(hc): The local kvcache is incomplete when DCP is triggered,
+                # an additional kvcache allgather across the DCP group is therefore
+                # required, so the workspace has to be enlarged by 1/DCP relative
+                # to the original TP allocation.
+                assert self.chunked_prefill_workspace_size % \
+                    self.dcp_world_size == 0
+                self.chunked_prefill_workspace = torch.empty(
+                    (self.chunked_prefill_workspace_size +
+                     self.chunked_prefill_workspace_size // self.dcp_world_size,
+                     self.model_config.get_head_size()),
+                    dtype=self.model_config.dtype,
+                    device=device,
+                )
+            else:
                 self.chunked_prefill_workspace = torch.empty(
                     (self.chunked_prefill_workspace_size,
                      self.model_config.get_head_size()),
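
A quick arithmetic check of the enlarged workspace (sizes illustrative): the extra `1/DCP` rows produce a buffer that splits cleanly into a local-gather region of N rows and an all-gather region of N*dcp_size rows, matching the layout comment that appears later in `_context_parallel_compute_prefill_context`:

```python
workspace_size = 128 * 1024  # chunked_prefill_workspace_size (illustrative)
dcp_world_size = 2

# DCP allocation: original size plus 1/DCP extra rows
total_rows = workspace_size + workspace_size // dcp_world_size

# layout used later for prefill context chunks:
# |------- N tokens --------|--------- N*dcp_size tokens ----------|
# |<- use for local gather ->|<-------- use for allgather -------->|
allgather_offset = total_rows // (dcp_world_size + 1)  # N
assert allgather_offset * (dcp_world_size + 1) == total_rows
assert allgather_offset == workspace_size // dcp_world_size
assert total_rows - allgather_offset == allgather_offset * dcp_world_size
```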
@@ -631,6 +661,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
             split_decodes_and_prefills(common_attn_metadata,
                                        decode_threshold=self.reorder_batch_threshold)

+        # Note(hc): update seq_lens of decode reqs under DCP.
+        if self.dcp_world_size > 1:
+            seq_lens[:num_decodes] = seq_lens[:num_decodes] \
+                // self.dcp_world_size + (self.dcp_rank <= \
+                (seq_lens[:num_decodes] - 1) % self.dcp_world_size)
+
         assert num_decodes + num_prefills == num_reqs
         assert num_decode_tokens + num_prefill_tokens == num_tokens
@@ -639,6 +675,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
             reqs_start = num_decodes  # prefill_start

             context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
+            # Note(hc): The context lengths in the perspective of dcp rank0.
+            cp_context_lens_cpu = torch.ceil(context_lens_cpu.float() /
+                                             self.dcp_world_size).int()
+            origin_context_lens = context_lens_cpu.tolist()
             max_context_len_cpu = context_lens_cpu.max().item()
             num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
             prefill_query_start_loc = query_start_loc[
@@ -691,14 +731,60 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                          out=cu_seq_lens_cpu[:, 1:],
                          dtype=torch.int32)

+            if self.dcp_world_size > 1:
+                # Note(hc): The above max_context_chunk already enforces
+                # block_size alignment, DCP just need the block_size can
+                # be divisible by dcp_world_size, because DCP use
+                # cp_gather_cache which not require `cp_chunk_starts`
+                # aligned to page_size.
+                assert max_context_chunk % self.dcp_world_size == 0
+                cp_max_context_chunk = max_context_chunk // \
+                    self.dcp_world_size
+                cp_chunk_starts = \
+                    torch.arange(num_chunks, dtype=torch.int32) \
+                    .unsqueeze(1).expand(-1, num_prefills) \
+                    * cp_max_context_chunk
+                cp_chunk_ends = torch.min(
+                    cp_context_lens_cpu.unsqueeze(0),
+                    cp_chunk_starts + cp_max_context_chunk)
+                cp_chunk_seq_lens = (cp_chunk_ends -
+                                     cp_chunk_starts).clamp(min=0)
+
+                cp_cu_seq_lens_cpu = torch.zeros(num_chunks,
+                                                 num_prefills + 1,
+                                                 dtype=torch.int32,
+                                                 pin_memory=True)
+                torch.cumsum(cp_chunk_seq_lens,
+                             dim=1,
+                             out=cp_cu_seq_lens_cpu[:, 1:],
+                             dtype=torch.int32)
+
             chunked_context_metadata_cls = \
                 CudnnPrefillMetadata.ChunkedContextMetadata \
                 if self._use_cudnn_prefill else \
                 MLACommonPrefillMetadata.ChunkedContextMetadata
+            if self.dcp_world_size > 1:
                 chunked_context_metadata = \
                     chunked_context_metadata_cls(
-                    cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
+                        cu_seq_lens=cu_seq_lens_cpu \
+                            .to(device, non_blocking=True),
+                        starts=cp_chunk_starts.to(device, non_blocking=True),
+                        seq_tot=cp_chunk_seq_lens.sum(dim=1).tolist(),
+                        max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
+                        seq_lens=chunk_seq_lens,
+                        workspace=self.chunked_prefill_workspace,
+                        cp_chunk_seq_lens=cp_chunk_seq_lens.tolist(),
+                        origin_context_lens=origin_context_lens,
+                        cp_cu_seq_lens=cp_cu_seq_lens_cpu \
+                            .to(device, non_blocking=True),
+                        chunk_size=max_context_chunk,
+                        cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
+                    )
+            else:
+                chunked_context_metadata = \
+                    chunked_context_metadata_cls(
+                        cu_seq_lens=cu_seq_lens_cpu \
+                            .to(device, non_blocking=True),
                         starts=chunk_starts.to(device, non_blocking=True),
                         seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                         max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
@@ -757,6 +843,71 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         return attn_metadata


+def reorg_kvcache(
+    allgatered_kv_c_normed: torch.Tensor,
+    allgatered_k_pe: torch.Tensor,
+    cp_chunk_seq_lens_lst: list[int],
+    origin_context_lens: list[int],
+    cp_world_size: int,
+    sum_seq_len: int,
+    max_seq_len: int,
+    chunk_size: int,
+    chunk_idx: int,
+    toks: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    reorg kvcache after cp local gather to tp layout for attn kernel.
+
+    Args:
+        cp_chunk_seq_lens_lst: chunk context lengths under CP.
+        origin_context_lens: origin full context lengths under CP.
+        cp_world_size: CP size.
+        sum_seq_len: the sum of cp_chunk_seq_lens_lst.
+        max_seq_len: the max value of cp_chunk_seq_lens_lst.
+        chunk_size: equals to max_context_chunk from
+            chunked_context_metadata building.
+        chunk_idx: chunk idx of chunked_prefill.
+        toks: the number of tokens for local gather cache.
+    """
+    kv_c_segments = []
+    k_pe_segments = []
+    src_token_idx = 0
+    max_seq_len_check = 0
+    for cp_chunk_seq_len, origin_context_len in zip(cp_chunk_seq_lens_lst,
+                                                    origin_context_lens):
+        chunk_context_len = chunk_size
+        if cp_chunk_seq_len != 0:
+            chunk_context_len = min(
+                chunk_context_len, origin_context_len - chunk_size * chunk_idx)
+        cp_target_rank = (chunk_context_len - 1) % cp_world_size
+        cur_seq_len = 0
+        for rank in range(cp_world_size):
+            if rank > cp_target_rank and cp_chunk_seq_len:
+                real_cp_chunk_seq_len = cp_chunk_seq_len - 1
+            else:
+                real_cp_chunk_seq_len = cp_chunk_seq_len
+            if real_cp_chunk_seq_len:
+                kv_c_segment = allgatered_kv_c_normed[rank * toks +
+                                                      src_token_idx:rank *
+                                                      toks + src_token_idx +
+                                                      real_cp_chunk_seq_len]
+                k_pe_segment = allgatered_k_pe[rank * toks +
+                                               src_token_idx:rank * toks +
+                                               src_token_idx +
+                                               real_cp_chunk_seq_len]
+                kv_c_segments.append(kv_c_segment)
+                k_pe_segments.append(k_pe_segment)
+                cur_seq_len += real_cp_chunk_seq_len
+                max_seq_len_check = max(max_seq_len_check, cur_seq_len)
+        src_token_idx += cp_chunk_seq_len
+    reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
+    reorganized_k_pe = torch.cat(k_pe_segments, dim=0)
+    assert reorganized_kv_c_normed.shape[0] == sum_seq_len
+    assert reorganized_k_pe.shape[0] == sum_seq_len
+    assert max_seq_len_check == max_seq_len
+    return reorganized_kv_c_normed, reorganized_k_pe
+
+
 class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
     """
     NOTE: Please read the comment at the top of the file before trying to
@ -836,6 +987,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
|||||||
self.vllm_flash_attn_version == 3
|
self.vllm_flash_attn_version == 3
|
||||||
and current_platform.get_device_capability()[0] == 9)
|
and current_platform.get_device_capability()[0] == 9)
|
||||||
|
|
||||||
|
self.dcp_world_size: Optional[int] = None
|
||||||
|
|
||||||
def _flash_attn_varlen_diff_headdims(self,
|
def _flash_attn_varlen_diff_headdims(self,
|
||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
@ -1152,6 +1305,108 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
|||||||
|
|
||||||
return output, output_lse
|
return output, output_lse
|
||||||
|
|
||||||
|
def _context_parallel_compute_prefill_context(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
kv_c_and_k_pe_cache: torch.Tensor,
|
||||||
|
attn_metadata: MLACommonMetadata,
|
||||||
|
k_scale: torch.Tensor,
|
||||||
|
dcp_world_size: int,
|
||||||
|
):
|
||||||
|
assert k_scale is None, "DCP not support sacled kvcache now."
|
||||||
|
assert attn_metadata.prefill is not None
|
||||||
|
prefill_metadata = attn_metadata.prefill
|
||||||
|
assert prefill_metadata.chunked_context is not None
|
||||||
|
assert prefill_metadata.chunked_context.cp_chunk_seq_lens is not None
|
||||||
|
assert prefill_metadata.chunked_context.origin_context_lens is not None
|
||||||
|
assert prefill_metadata.chunked_context.cp_cu_seq_lens is not None
|
||||||
|
assert prefill_metadata.chunked_context.chunk_size is not None
|
||||||
|
assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None
|
||||||
|
|
||||||
|
output = None
|
||||||
|
iters = len(prefill_metadata.chunked_context.seq_tot)
|
||||||
|
workspace = prefill_metadata.chunked_context.workspace
|
||||||
|
|
||||||
|
for i in range(iters):
|
||||||
|
toks = prefill_metadata.chunked_context.seq_tot[i]
|
||||||
|
ops.cp_gather_cache(
|
||||||
|
src_cache=kv_c_and_k_pe_cache,
|
||||||
|
dst=workspace,
|
||||||
|
block_table=prefill_metadata.block_table,
|
||||||
|
cu_seq_lens=prefill_metadata.chunked_context.cp_cu_seq_lens[i],
|
||||||
|
batch_size=attn_metadata.num_prefills,
|
||||||
|
seq_starts=prefill_metadata.chunked_context.starts[i],
|
||||||
|
)
|
||||||
|
# workspace
|
||||||
|
# |------- N tokens --------|--------- N*dcp_size tokens ----------|
|
||||||
|
# |<- use for loca_gather ->|<--------- use for allgather -------->|
|
||||||
|
allgather_offset = workspace.shape[0] // (dcp_world_size + 1)
|
||||||
|
assert allgather_offset * (dcp_world_size +
|
||||||
|
1) == workspace.shape[0]
|
||||||
|
assert toks <= allgather_offset
|
||||||
|
local_gathered_kvcache = workspace[:toks]
|
||||||
|
cur_allgather_workspace = workspace[
|
||||||
|
allgather_offset:allgather_offset * (1 + dcp_world_size)]
|
||||||
|
assert toks * dcp_world_size <= cur_allgather_workspace.shape[0]
|
||||||
|
cur_allgather_kvcache = cur_allgather_workspace[:toks *
|
||||||
|
dcp_world_size]
|
||||||
|
cur_allgather_kvcache.copy_(get_dcp_group().all_gather(
|
||||||
|
local_gathered_kvcache, dim=0))
|
||||||
|
assert cur_allgather_kvcache.shape[
|
||||||
|
-1] == self.kv_lora_rank + self.qk_rope_head_dim
|
||||||
|
allgatered_kv_c_normed, allgatered_k_pe = \
|
||||||
|
cur_allgather_kvcache.unsqueeze(
|
||||||
|
1).split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||||
|
|
||||||
|
kv_c_normed, k_pe = reorg_kvcache(
|
||||||
|
allgatered_kv_c_normed,
|
||||||
|
allgatered_k_pe,
|
||||||
|
cp_chunk_seq_lens_lst=prefill_metadata.chunked_context.
|
||||||
|
cp_chunk_seq_lens[i],
|
||||||
|
origin_context_lens=prefill_metadata.chunked_context.
|
||||||
|
origin_context_lens,
|
||||||
|
cp_world_size=dcp_world_size,
|
||||||
|
sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i]
|
||||||
|
[-1],
|
||||||
|
max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i],
|
||||||
|
chunk_size=prefill_metadata.chunked_context.chunk_size,
|
||||||
|
chunk_idx=i,
|
||||||
|
toks=toks)
|
||||||
|
|
||||||
|
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
|
||||||
|
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||||
|
k_nope, v = kv_nope\
|
||||||
|
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||||
|
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
|
||||||
|
dim=-1)
|
||||||
|
|
||||||
|
attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
|
||||||
|
prefill=prefill_metadata,
|
||||||
|
chunk_idx=i,
|
||||||
|
q=q,
|
||||||
|
k=k,
|
||||||
|
v=v,
|
||||||
|
)
|
||||||
|
|
||||||
|
if output is None:
|
||||||
|
output = attn_output
|
||||||
|
output_lse = attn_softmax_lse
|
||||||
|
else:
|
||||||
|
output_tmp = torch.empty_like(output)
|
||||||
|
output_lse_tmp = torch.empty_like(output_lse)
|
||||||
|
merge_attn_states(
|
||||||
|
output=output_tmp,
|
||||||
|
output_lse=output_lse_tmp,
|
||||||
|
prefix_output=output,
|
||||||
|
prefix_lse=output_lse,
|
||||||
|
suffix_output=attn_output,
|
||||||
|
suffix_lse=attn_softmax_lse,
|
||||||
|
)
|
||||||
|
output = output_tmp
|
||||||
|
output_lse = output_lse_tmp
|
||||||
|
|
||||||
|
return output, output_lse
|
||||||
|
|
||||||
def _forward_prefill(
|
def _forward_prefill(
|
||||||
self,
|
self,
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
@ -1162,6 +1417,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
|||||||
k_scale: torch.Tensor,
|
k_scale: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert attn_metadata.prefill is not None
|
assert attn_metadata.prefill is not None
|
||||||
|
assert self.dcp_world_size is not None
|
||||||
|
|
||||||
has_context = attn_metadata.prefill.chunked_context is not None
|
has_context = attn_metadata.prefill.chunked_context is not None
|
||||||
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
|
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
|
||||||
@ -1181,7 +1437,14 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
|||||||
|
|
||||||
if has_context:
|
if has_context:
|
||||||
suffix_output, suffix_lse = output
|
suffix_output, suffix_lse = output
|
||||||
context_output, context_lse = self._compute_prefill_context( \
|
if self.dcp_world_size > 1:
|
||||||
|
context_output, context_lse = \
|
||||||
|
self._context_parallel_compute_prefill_context(
|
||||||
|
q, kv_c_and_k_pe_cache, attn_metadata,
|
||||||
|
k_scale=None, dcp_world_size=self.dcp_world_size)
|
||||||
|
else:
|
||||||
|
context_output, context_lse = \
|
||||||
|
self._compute_prefill_context(
|
||||||
q, kv_c_and_k_pe_cache, attn_metadata, k_scale)
|
q, kv_c_and_k_pe_cache, attn_metadata, k_scale)
|
||||||
|
|
||||||
output = torch.empty_like(suffix_output)
|
output = torch.empty_like(suffix_output)
|
||||||
@ -1202,12 +1465,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _forward_decode(
|
def _forward_decode(
|
||||||
self,
|
self,
|
||||||
ql_nope: torch.Tensor,
|
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
||||||
q_pe: torch.Tensor,
|
|
||||||
kv_c_and_k_pe_cache: torch.Tensor,
|
kv_c_and_k_pe_cache: torch.Tensor,
|
||||||
attn_metadata: M,
|
attn_metadata: M,
|
||||||
layer: AttentionLayer,
|
layer: AttentionLayer,
|
||||||
) -> torch.Tensor:
|
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -1235,6 +1497,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
|||||||
# same expert outputs.
|
# same expert outputs.
|
||||||
return output.fill_(0)
|
return output.fill_(0)
|
||||||
|
|
||||||
|
if self.dcp_world_size is None:
|
||||||
|
self.dcp_world_size = get_dcp_group().world_size
|
||||||
|
|
||||||
fp8_attention = self.kv_cache_dtype.startswith("fp8")
|
fp8_attention = self.kv_cache_dtype.startswith("fp8")
|
||||||
|
|
||||||
num_actual_toks = attn_metadata.num_actual_tokens
|
num_actual_toks = attn_metadata.num_actual_tokens
|
||||||
@@ -1313,7 +1578,26 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
                     layer._q_scale)
                 decode_q_pe = decode_q_pe.reshape(q_pe_shape)

-            output[:num_decode_tokens] = self._forward_decode(
-                decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer)
+            decode_q = (decode_ql_nope, decode_q_pe)
+            if self.dcp_world_size > 1:
+                assert not fp8_attention, "DCP not support fp8 kvcache now."
+                # concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P)
+                decode_q = torch.cat(decode_q, dim=-1)
+                # all-gather decode_q along the head dim
+                decode_q = get_dcp_group().all_gather(decode_q, dim=1)
+
+            # call decode attn
+            attn_out, lse = self._forward_decode(decode_q, kv_cache,
+                                                 attn_metadata, layer)
+
+            # correct the DCP attn_out with the lse
+            if self.dcp_world_size > 1:
+                assert lse is not None, (
+                    "For an MLA backend to enable DCP, it is mandatory that "
+                    "the corresponding decode attn kernel returns the "
+                    "softmax lse.")
+                attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
+
+            # v_up projection
+            output[:num_decode_tokens] = self._v_up_proj(attn_out)

         return output_padded
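The `cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())` call above is what makes the sharded decode mathematically equivalent to full attention: every DCP rank attends only to its local, interleaved slice of the KV cache, so the partial outputs must be re-weighted by their softmax log-sum-exp before being combined (in vLLM the helper also moves data across the DCP group, roughly an all-gather of the lse and a reduce-scatter of the weighted outputs). A minimal single-process sketch of just the merge step, with illustrative names and shapes rather than vLLM's actual API:

import torch

def merge_partial_attn(outs: torch.Tensor, lses: torch.Tensor) -> torch.Tensor:
    # outs: [R, B, H, D] partial attention outputs, one per DCP rank's KV shard
    # lses: [R, B, H]    log-sum-exp of each rank's local attention scores
    lse_max = lses.max(dim=0, keepdim=True).values
    weights = torch.exp(lses - lse_max)               # per-shard softmax weights
    weights = weights / weights.sum(dim=0, keepdim=True)
    return (weights.unsqueeze(-1) * outs).sum(dim=0)  # [B, H, D]

The same log-sum-exp weighting is what merges `context_output`/`context_lse` with `suffix_output`/`suffix_lse` in the chunked-prefill path above.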
@@ -232,7 +232,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
             self._workspace.get_buf(),
             self.scale, self._num_kv_splits)

-        return self._v_up_proj(o)
+        return o

     # TODO: Currently we leave it here only for backup in case something is
     # wrong with the new SM100 CUTLASS MLA kernel
@@ -265,21 +265,25 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
             attn_metadata.decode.seq_lens,
             attn_metadata.decode.block_table, self.scale)

-        return self._v_up_proj(o)
+        return o

     def _forward_decode(
         self,
-        q_nope: torch.Tensor,
-        q_pe: torch.Tensor,
+        q: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
         layer: AttentionLayer,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        if type(q) is tuple:
+            q_nope, q_pe = q
+        else:
+            q_nope, q_pe = torch.split(
+                q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         if self._use_old_cutlass_mla:
             # TODO: Remove the old cutlass MLA kernel after more extensive
             # testing
             return self._old_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
-                                            attn_metadata)
+                                            attn_metadata), None

         return self._sm100_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
-                                          attn_metadata)
+                                          attn_metadata), None
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from dataclasses import dataclass
-from typing import ClassVar, Optional
+from typing import ClassVar, Optional, Union

 import torch

@@ -154,15 +154,20 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):

     def _forward_decode(
         self,
-        q_nope: torch.Tensor,
-        q_pe: torch.Tensor,
+        q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: FlashAttnMLAMetadata,
         layer: AttentionLayer,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         assert kv_c_and_k_pe_cache.numel() > 0
         assert attn_metadata.decode is not None

+        if type(q) is tuple:
+            q_nope, q_pe = q
+        else:
+            q_nope, q_pe = torch.split(
+                q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+
         if self.kv_cache_dtype.startswith("fp8"):
             raise NotImplementedError(
                 "FP8 FlashAttention MLA not yet supported")
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from dataclasses import dataclass
-from typing import ClassVar, Optional
+from typing import ClassVar, Optional, Union

 import torch

@@ -169,20 +169,20 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):

     def _forward_decode(
         self,
-        q_nope: torch.Tensor,
-        q_pe: torch.Tensor,
+        q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: FlashMLAMetadata,
         layer: AttentionLayer,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         assert kv_c_and_k_pe_cache.numel() > 0
         assert attn_metadata.decode is not None

-        q = torch.cat([q_nope, q_pe], dim=-1)\
-            .unsqueeze(1)  # Add seqlen dim of 1 (decode)
+        if type(q) is tuple:
+            q = torch.cat(q, dim=-1)

-        o, _ = flash_mla_with_kvcache(
-            q=q,
+        assert isinstance(q, torch.Tensor)
+        o, lse = flash_mla_with_kvcache(
+            q=q.unsqueeze(1),  # Add seqlen dim of 1 (decode)
             k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
             block_table=attn_metadata.decode.block_table,
             cache_seqlens=attn_metadata.decode.seq_lens,
@@ -196,4 +196,4 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
             descale_k=layer._k_scale.reshape(1),
         )

-        return self._v_up_proj(o)
+        return o, lse
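The new `q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]` parameter seen in these `_forward_decode` signatures carries the same data in two layouts: a `(q_nope, q_pe)` tuple on the ordinary path, or a single pre-concatenated tensor on the DCP path, where the caller concatenates first so the query can be all-gathered along the head dim. A small illustrative round trip, assuming DeepSeek-style MLA sizes of kv_lora_rank=512 and qk_rope_head_dim=64 (these numbers are for illustration only):

import torch

B, H = 4, 16                                # decode tokens and heads on this rank
kv_lora_rank, qk_rope_head_dim = 512, 64    # assumed MLA latent/RoPE dims

q_nope = torch.randn(B, H, kv_lora_rank)
q_pe = torch.randn(B, H, qk_rope_head_dim)

# Tuple form (non-DCP): the backend concatenates internally.
q_cat = torch.cat((q_nope, q_pe), dim=-1)   # [B, H, 576], i.e. (B, N, L + P)

# Tensor form (DCP): already concatenated; a backend that needs the parts
# back can split on the same boundary.
q_nope2, q_pe2 = torch.split(q_cat, [kv_lora_rank, qk_rope_head_dim], dim=-1)
assert torch.equal(q_nope2, q_nope) and torch.equal(q_pe2, q_pe)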
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from dataclasses import dataclass
-from typing import ClassVar, Optional
+from typing import ClassVar, Optional, Union

 import torch

@@ -220,18 +220,19 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):

     def _forward_decode(
         self,
-        q_nope: torch.Tensor,
-        q_pe: torch.Tensor,
+        q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: AiterMLAMetadata,
         layer: AttentionLayer,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         assert kv_c_and_k_pe_cache.numel() > 0
         assert attn_metadata.decode is not None

-        B = q_nope.shape[0]
+        if type(q) is tuple:
+            q = torch.cat(q, dim=-1)

-        q = torch.cat([q_nope, q_pe], dim=-1)
+        assert isinstance(q, torch.Tensor)
+        B = q.shape[0]
         o = torch.zeros(B,
                         self.num_heads,
                         self.kv_lora_rank,
@@ -249,4 +250,4 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
             attn_metadata.decode.paged_kv_indices,
             attn_metadata.decode.paged_kv_last_page_len)

-        return self._v_up_proj(o)
+        return o, None
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import Optional
+from typing import Optional, Union

 import torch

@@ -123,21 +123,22 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):

     def _forward_decode(
         self,
-        q_nope: torch.Tensor,
-        q_pe: torch.Tensor,
+        q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
         layer: AttentionLayer,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         assert kv_c_and_k_pe_cache.numel() > 0
         assert attn_metadata.decode is not None

         if self.kv_cache_dtype.startswith("fp8"):
             raise NotImplementedError("FP8 Triton MLA not yet supported")

-        B = q_nope.shape[0]
+        if type(q) is tuple:
+            q = torch.cat(q, dim=-1)

-        q = torch.cat([q_nope, q_pe], dim=-1)
+        assert isinstance(q, torch.Tensor)
+        B = q.shape[0]
         o = torch.zeros(B,
                         self.num_heads,
                         self.kv_lora_rank,
@@ -171,4 +172,4 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
             attn_metadata.decode.seq_lens, attn_logits,
             num_kv_splits, self.scale, PAGE_SIZE)

-        return self._v_up_proj(o)
+        return o, None
@@ -24,6 +24,7 @@ class KVCacheCoordinator(ABC):
         use_eagle: bool,
         enable_caching: bool,
         enable_kv_cache_events: bool,
+        dcp_world_size: int,
     ):
         self.kv_cache_config = kv_cache_config
         self.max_model_len = max_model_len
@@ -39,6 +40,7 @@ class KVCacheCoordinator(ABC):
                 kv_cache_spec=kv_cache_group.kv_cache_spec,
                 block_pool=self.block_pool,
                 kv_cache_group_id=i,
+                dcp_world_size=dcp_world_size,
             ) for i, kv_cache_group in enumerate(
                 self.kv_cache_config.kv_cache_groups))

@@ -197,9 +199,14 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
     """

     def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
-                 use_eagle: bool, enable_kv_cache_events: bool):
-        super().__init__(kv_cache_config, max_model_len, use_eagle, False,
-                         enable_kv_cache_events)
+                 use_eagle: bool, enable_kv_cache_events: bool,
+                 dcp_world_size: int):
+        super().__init__(kv_cache_config,
+                         max_model_len,
+                         use_eagle,
+                         False,
+                         enable_kv_cache_events,
+                         dcp_world_size=dcp_world_size)
         self.num_single_type_manager = len(self.single_type_managers)

     def get_num_common_prefix_blocks(self, request_id: str,
@@ -225,12 +232,19 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):

     def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
                  use_eagle: bool, enable_caching: bool,
-                 enable_kv_cache_events: bool):
-        super().__init__(kv_cache_config, max_model_len, use_eagle,
-                         enable_caching, enable_kv_cache_events)
+                 enable_kv_cache_events: bool, dcp_world_size: int):
+        super().__init__(kv_cache_config,
+                         max_model_len,
+                         use_eagle,
+                         enable_caching,
+                         enable_kv_cache_events,
+                         dcp_world_size=dcp_world_size)
         self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[
             0].kv_cache_spec
         self.block_size = self.kv_cache_spec.block_size
+        self.dcp_world_size = dcp_world_size
+        if dcp_world_size > 1:
+            self.block_size *= dcp_world_size
         assert len(self.kv_cache_config.kv_cache_groups) == 1, (
             "UnitaryKVCacheCoordinator assumes only one kv cache group")

@@ -246,6 +260,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
             block_pool=self.block_pool,
             kv_cache_spec=self.kv_cache_spec,
             use_eagle=self.use_eagle,
+            dcp_world_size=self.dcp_world_size,
         )
         return hit_blocks, len(hit_blocks[0]) * self.block_size

@@ -261,9 +276,14 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):

     def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
                  use_eagle: bool, enable_caching: bool,
-                 enable_kv_cache_events: bool):
-        super().__init__(kv_cache_config, max_model_len, use_eagle,
-                         enable_caching, enable_kv_cache_events)
+                 enable_kv_cache_events: bool, dcp_world_size: int):
+        super().__init__(kv_cache_config,
+                         max_model_len,
+                         use_eagle,
+                         enable_caching,
+                         enable_kv_cache_events,
+                         dcp_world_size=dcp_world_size)
+        assert dcp_world_size == 1, "DCP not support hybrid attn now."
         self.verify_and_split_kv_cache_groups()

     def verify_and_split_kv_cache_groups(self) -> None:
@@ -394,17 +414,27 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
         return hit_blocks, hit_length


-def get_kv_cache_coordinator(
-        kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool,
-        enable_caching: bool,
-        enable_kv_cache_events: bool) -> KVCacheCoordinator:
+def get_kv_cache_coordinator(kv_cache_config: KVCacheConfig,
+                             max_model_len: int, use_eagle: bool,
+                             enable_caching: bool,
+                             enable_kv_cache_events: bool,
+                             dcp_world_size: int) -> KVCacheCoordinator:
     if not enable_caching:
-        return KVCacheCoordinatorNoPrefixCache(kv_cache_config, max_model_len,
+        return KVCacheCoordinatorNoPrefixCache(kv_cache_config,
+                                               max_model_len,
                                                use_eagle,
-                                               enable_kv_cache_events)
+                                               enable_kv_cache_events,
+                                               dcp_world_size=dcp_world_size)
     if len(kv_cache_config.kv_cache_groups) == 1:
-        return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len,
-                                         use_eagle, enable_caching,
-                                         enable_kv_cache_events)
-    return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle,
-                                    enable_caching, enable_kv_cache_events)
+        return UnitaryKVCacheCoordinator(kv_cache_config,
+                                         max_model_len,
+                                         use_eagle,
+                                         enable_caching,
+                                         enable_kv_cache_events,
+                                         dcp_world_size=dcp_world_size)
+    return HybridKVCacheCoordinator(kv_cache_config,
+                                    max_model_len,
+                                    use_eagle,
+                                    enable_caching,
+                                    enable_kv_cache_events,
+                                    dcp_world_size=dcp_world_size)
@@ -91,6 +91,7 @@ class KVCacheManager:
         use_eagle: bool = False,
         log_stats: bool = False,
         enable_kv_cache_events: bool = False,
+        dcp_world_size: int = 1,
     ) -> None:
         self.max_model_len = max_model_len

@@ -109,12 +110,20 @@ class KVCacheManager:
         self.block_size = kv_cache_config.kv_cache_groups[
             0].kv_cache_spec.block_size

+        if dcp_world_size > 1:
+            assert len(kv_cache_config.kv_cache_groups) == 1
+            # Note(hc): needs revisiting. When both DCP and any future
+            # PCP are enabled, the block_size may need to be scaled
+            # by a factor of dcp_size × pcp_size.
+            self.block_size *= dcp_world_size
+
         self.coordinator = get_kv_cache_coordinator(
             kv_cache_config=kv_cache_config,
             max_model_len=self.max_model_len,
             use_eagle=self.use_eagle,
             enable_caching=self.enable_caching,
             enable_kv_cache_events=enable_kv_cache_events,
+            dcp_world_size=dcp_world_size,
         )
         self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
         self.block_pool = self.coordinator.block_pool
@@ -846,6 +846,12 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
     )

     num_tokens = num_blocks * vllm_config.cache_config.block_size
+    if vllm_config.parallel_config.decode_context_parallel_size > 1:
+        num_tokens *= vllm_config.parallel_config.decode_context_parallel_size
+        logger.info(
+            "Multiplying the GPU KV cache size by the dcp_world_size %d.",
+            vllm_config.parallel_config.decode_context_parallel_size)
+
     num_tokens_str = f"{num_tokens:,}"
     logger.info("GPU KV cache size: %s tokens", num_tokens_str)
     max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
@@ -100,6 +100,15 @@ class Scheduler(SchedulerInterface):

         self.block_size = self.cache_config.block_size

+        self.dcp_world_size = \
+            vllm_config.parallel_config.decode_context_parallel_size
+        # Note(hc): The scheduler's block_size must be multiplied
+        # by dcp_world_size, since block hashes are computed on the
+        # original full token sequence at a granularity of
+        # original_block_size × dcp_world_size.
+        if self.dcp_world_size > 1:
+            self.block_size *= self.dcp_world_size
+
         # req_id -> Request
         self.requests: dict[str, Request] = {}
         # Scheduling policy
@@ -161,6 +170,7 @@ class Scheduler(SchedulerInterface):
             use_eagle=self.use_eagle,
             log_stats=self.log_stats,
             enable_kv_cache_events=self.enable_kv_cache_events,
+            dcp_world_size=self.dcp_world_size,
         )
         self.use_pp = self.parallel_config.pipeline_parallel_size > 1

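The scaling above is easiest to sanity-check with concrete numbers: because consecutive tokens are interleaved across DCP ranks, one physical KV-cache block of `block_size` slots on a single rank covers `block_size * dcp_world_size` consecutive tokens of the request, so the scheduler must hash and count blocks at that coarser granularity. A small illustrative check (the numbers are examples, not defaults from this patch):

block_size = 16          # physical slots per KV-cache block on one rank
dcp_world_size = 4       # ranks in the decode-context-parallel group

# Tokens covered by one block when every dcp_world_size-th token lands on
# the same rank (interleaved layout).
logical_block_tokens = block_size * dcp_world_size      # 64

# A 1024-token prompt is therefore hashed/allocated as 16 logical blocks,
# and each rank stores its 256 tokens in 16 physical blocks.
prompt_len = 1024
assert prompt_len // logical_block_tokens == 16
assert prompt_len // dcp_world_size // block_size == 16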
@@ -25,6 +25,7 @@ class SingleTypeKVCacheManager(ABC):
         kv_cache_spec: KVCacheSpec,
         block_pool: BlockPool,
         kv_cache_group_id: int,
+        dcp_world_size: int = 1,
     ) -> None:
         """
         Initializes the SingleTypeKVCacheManager.
@@ -33,8 +34,10 @@ class SingleTypeKVCacheManager(ABC):
             block_pool: The block pool.
             kv_cache_group_id: The id of the kv cache group of this manager.
         """

         self.block_size = kv_cache_spec.block_size
+        self.dcp_world_size = dcp_world_size
+        if self.dcp_world_size > 1:
+            self.block_size *= dcp_world_size
         self.kv_cache_spec = kv_cache_spec
         self.block_pool = block_pool

@@ -196,6 +199,7 @@ class SingleTypeKVCacheManager(ABC):
         block_pool: BlockPool,
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
+        dcp_world_size: int = 1,
     ) -> tuple[list[KVCacheBlock], ...]:
         """
         Get the longest cache hit prefix of the blocks that is not longer than
@@ -253,6 +257,7 @@ class FullAttentionManager(SingleTypeKVCacheManager):
         block_pool: BlockPool,
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
+        dcp_world_size: int = 1,
     ) -> tuple[list[KVCacheBlock], ...]:
         assert isinstance(
             kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec)
@@ -260,7 +265,10 @@ class FullAttentionManager(SingleTypeKVCacheManager):
             "and chunked local attention groups"
         computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
             [] for _ in range(len(kv_cache_group_ids)))
-        max_num_blocks = max_length // kv_cache_spec.block_size
+        block_size = kv_cache_spec.block_size
+        if dcp_world_size > 1:
+            block_size *= dcp_world_size
+        max_num_blocks = max_length // block_size
         for block_hash in itertools.islice(block_hashes, max_num_blocks):
             # block_hashes is a chain of block hashes. If a block hash is not
             # in the cached_block_hash_to_id, the following block hashes are
@@ -310,9 +318,11 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
         block_pool: BlockPool,
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
+        dcp_world_size: int = 1,
     ) -> tuple[list[KVCacheBlock], ...]:
         assert isinstance(kv_cache_spec, SlidingWindowSpec), (
             "SlidingWindowManager can only be used for sliding window groups")
+        assert dcp_world_size == 1, "DCP not support sliding window attn now."

         # The number of contiguous blocks needed for prefix cache hit.
         # -1 since the input token itself is also included in the window
@@ -408,6 +418,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
         block_pool: BlockPool,
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
+        dcp_world_size: int = 1,
     ) -> tuple[list[KVCacheBlock], ...]:
         """
         For chunked local attention, we need to find the longest cache hit
@@ -445,6 +456,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
             "chunked local attention groups")
         assert use_eagle is False, ("Hybrid KV cache is not supported for " +
                                     "eagle + chunked local attention.")
+        assert dcp_world_size == 1, "DCP not support chunked local attn now."
         max_num_blocks = max_length // kv_cache_spec.block_size
         if max_length > 0:
             local_attention_start_idx = (max_length //
@@ -525,10 +537,12 @@ class MambaManager(SingleTypeKVCacheManager):
         block_pool: BlockPool,
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
+        dcp_world_size: int = 1,
     ) -> tuple[list[KVCacheBlock], ...]:
         assert isinstance(
             kv_cache_spec,
             MambaSpec), ("MambaManager can only be used for mamba groups")
+        assert dcp_world_size == 1, "DCP not support mamba now."
         # Prefix caching is not supported for mamba now. Always return empty
         # list.
         computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
@@ -583,6 +597,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
         block_pool: BlockPool,
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
+        dcp_world_size: int = 1,
     ) -> tuple[list[KVCacheBlock], ...]:
         assert isinstance(kv_cache_spec, CrossAttentionSpec), (
             "CrossAttentionManager can only be used for cross-attention groups"
@@ -86,6 +86,12 @@ class FullAttentionSpec(AttentionSpec):

     def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
         max_model_len = vllm_config.model_config.max_model_len
+        dcp_world_size = \
+            vllm_config.parallel_config.decode_context_parallel_size
+        # Note(hc): each dcp rank only needs to save
+        # (max_model_len // dcp_world_size) tokens locally.
+        if dcp_world_size > 1:
+            max_model_len = cdiv(max_model_len, dcp_world_size)
         return cdiv(max_model_len, self.block_size) * self.page_size_bytes

     @classmethod
@@ -162,6 +168,8 @@ class SlidingWindowSpec(AttentionSpec):
         assert not self.use_mla, "MLA is not supported for sliding window"

     def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
+        assert vllm_config.parallel_config.decode_context_parallel_size == 1, \
+            "DCP not support sliding window."
         max_model_len = vllm_config.model_config.max_model_len
         max_num_batched_tokens = (
             vllm_config.scheduler_config.max_num_batched_tokens)
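A quick worked example of the per-rank sizing that `FullAttentionSpec.max_memory_usage_bytes` now performs (the concrete numbers, including the per-block page size, are made up for illustration):

def cdiv(a: int, b: int) -> int:
    # Ceiling division, mirroring vllm.utils.cdiv.
    return -(-a // b)

max_model_len = 131072
block_size = 16
page_size_bytes = 64 * 1024        # assumed bytes per KV-cache block
dcp_world_size = 4

# Each DCP rank only stores max_model_len // dcp_world_size tokens locally,
# so its worst-case KV footprint shrinks by the same factor.
local_len = cdiv(max_model_len, dcp_world_size)            # 32768 tokens
max_bytes = cdiv(local_len, block_size) * page_size_bytes  # 2048 blocks
assert max_bytes == 2048 * page_size_bytes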
@@ -4,6 +4,7 @@
 import numpy as np
 import torch

+from vllm.distributed import get_dcp_group
 from vllm.logger import init_logger
 from vllm.utils import cdiv

@@ -50,6 +51,13 @@ class BlockTable:
         self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
                                         dtype=torch.int64,
                                         device=self.device)
+        try:
+            self.dcp_world_size = get_dcp_group().world_size
+            self.dcp_rank = get_dcp_group().rank_in_group
+        except AssertionError:
+            # DCP might not be initialized in testing
+            self.dcp_world_size = 1
+            self.dcp_rank = 0

     def append_row(
         self,
@@ -89,6 +97,29 @@ class BlockTable:
         # NOTE(woosuk): We can't simply use `token_indices // block_size`
         # here because M (max_model_len) is not necessarily divisible by
         # block_size.
+        if self.dcp_world_size > 1:
+            # Note(hc): The DCP implementation stores the kvcache in an
+            # interleaved style: the kvcache for token_idx i is always
+            # stored on the GPU whose dcp_rank equals i % dcp_world_size.
+
+            # Use a "virtual block", which equals world_size * block_size,
+            # for the block_table_indices calculation.
+            virtual_block_size = self.block_size * self.dcp_world_size
+            block_table_indices = (req_indices * self.max_num_blocks_per_req +
+                                   positions // virtual_block_size)
+            block_numbers = self.block_table_np.ravel()[block_table_indices]
+            # Use virtual_block_size for mask calculation, which marks local
+            # tokens.
+            virtual_block_offsets = positions % virtual_block_size
+            mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank
+            # Calculate local block_offsets
+            block_offsets = virtual_block_offsets // self.dcp_world_size
+            # Calculate slot_mapping
+            slot_mapping = block_numbers * self.block_size + block_offsets
+            # Write final slots, use -1 for non-local tokens
+            self.slot_mapping_np[:req_indices.shape[0]] = np.where(
+                mask, slot_mapping, -1)
+        else:
             block_table_indices = (req_indices * self.max_num_blocks_per_req +
                                    positions // self.block_size)
             block_numbers = self.block_table_np.ravel()[block_table_indices]
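The interleaved slot-mapping arithmetic above can be reproduced standalone with numpy; the sketch below uses toy sizes and a made-up block table (the variable names mirror the code, the values do not come from vLLM):

import numpy as np

block_size, dcp_world_size, dcp_rank = 4, 2, 1
max_num_blocks_per_req = 8
virtual_block_size = block_size * dcp_world_size            # 8 tokens

# One request whose first virtual block maps to physical block 5.
block_table = np.zeros((1, max_num_blocks_per_req), dtype=np.int64)
block_table[0, 0] = 5

req_indices = np.zeros(6, dtype=np.int64)                   # all from request 0
positions = np.arange(6, dtype=np.int64)                    # token positions 0..5

block_table_indices = (req_indices * max_num_blocks_per_req
                       + positions // virtual_block_size)
block_numbers = block_table.ravel()[block_table_indices]    # all 5 here
virtual_block_offsets = positions % virtual_block_size
mask = virtual_block_offsets % dcp_world_size == dcp_rank   # odd positions local
block_offsets = virtual_block_offsets // dcp_world_size
slot_mapping = np.where(mask, block_numbers * block_size + block_offsets, -1)

# Rank 1 stores tokens 1, 3, 5 in slots 20, 21, 22; non-local tokens get -1.
assert slot_mapping.tolist() == [-1, 20, -1, 21, -1, 22]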
@@ -128,9 +159,19 @@ class MultiGroupBlockTable:
     def __init__(self, max_num_reqs: int, max_model_len: int,
                  max_num_batched_tokens: int, pin_memory: bool,
                  device: torch.device, block_sizes: list[int]) -> None:
+        # Note(hc): each dcp rank only stores
+        # (max_model_len // dcp_world_size) tokens in its kvcache,
+        # so the block_size used to calculate max_num_blocks_per_req
+        # must be multiplied by dcp_world_size.
+        try:
+            dcp_world_size = get_dcp_group().world_size
+        except AssertionError:
+            # DCP might not be initialized in testing
+            dcp_world_size = 1
+
         self.block_tables = [
-            BlockTable(block_size, max_num_reqs, cdiv(max_model_len,
-                                                      block_size),
+            BlockTable(block_size, max_num_reqs,
+                       cdiv(max_model_len, block_size * dcp_world_size),
                        max_num_batched_tokens, pin_memory, device)
             for block_size in block_sizes
         ]
@@ -56,6 +56,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, cdiv, check_use_alibi,
                         get_dtype_size, is_pin_memory_available, round_up,
                         supports_dynamo)
+from vllm.v1.attention.backends.mla.flashmla import FlashMLABackend
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
     create_fast_prefill_custom_backend,
@@ -187,6 +188,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             model_config.is_multimodal_raw_input_only_model)

         self.max_model_len = model_config.max_model_len
+        self.dcp_world_size = self.parallel_config.decode_context_parallel_size
         self.max_num_tokens = scheduler_config.max_num_batched_tokens
         self.max_num_reqs = scheduler_config.max_num_seqs

@@ -428,6 +430,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             return

         if self.reorder_batch_threshold is not None:
+            if self.dcp_world_size > 1:
+                assert self.reorder_batch_threshold == 1, \
+                    "DCP not support reorder_batch_threshold > 1 now."
             reorder_batch_to_split_decodes_and_prefills(
                 self.input_batch,
                 scheduler_output,
@@ -3305,6 +3310,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             get_kv_transfer_group().set_host_xfer_buffer_ops(
                 copy_kv_blocks)

+        if self.dcp_world_size > 1:
+            assert self.attn_groups[0][0].backend is FlashMLABackend, (
+                "DCP only supports FlashMLA for now. For an MLA backend to "
+                "enable DCP, it is mandatory that the corresponding decode "
+                "attn kernel returns the softmax lse.")
+
     def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
         """
         Add encoder-only layers to the KV cache config.
@@ -616,7 +616,9 @@ def init_worker_distributed_environment(
     init_distributed_environment(parallel_config.world_size, rank,
                                  distributed_init_method, local_rank, backend)

-    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
-                                      parallel_config.pipeline_parallel_size)
+    ensure_model_parallel_initialized(
+        parallel_config.tensor_parallel_size,
+        parallel_config.pipeline_parallel_size,
+        parallel_config.decode_context_parallel_size)

     ensure_kv_transfer_initialized(vllm_config)
@@ -539,8 +539,10 @@ def init_worker_distributed_environment(
     init_distributed_environment(parallel_config.world_size, rank,
                                  distributed_init_method, local_rank,
                                  current_platform.dist_backend)
-    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
-                                      parallel_config.pipeline_parallel_size)
+    ensure_model_parallel_initialized(
+        parallel_config.tensor_parallel_size,
+        parallel_config.pipeline_parallel_size,
+        parallel_config.decode_context_parallel_size)

     ensure_kv_transfer_initialized(vllm_config)
