diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index ad240023a003..b0d4c4456d33 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -837,7 +837,7 @@ steps: - pytest -v -s models/test_oot_registration.py # it needs a clean process - pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins -- label: Pipeline Parallelism Test # 45min +- label: Pipeline + Context Parallelism Test # 45min timeout_in_minutes: 60 mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" @@ -851,6 +851,7 @@ steps: commands: - pytest -v -s distributed/test_pp_cudagraph.py - pytest -v -s distributed/test_pipeline_parallel.py + # - pytest -v -s distributed/test_context_parallel.py # TODO: enable it on Hopper runners or add triton MLA support - label: LoRA TP Test (Distributed) # 17 min timeout_in_minutes: 30 diff --git a/csrc/cache.h b/csrc/cache.h index e8e069aefd9c..fd230bec27fc 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -36,13 +36,6 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe, const std::string& kv_cache_dtype, torch::Tensor& scale); -void cp_fused_concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe, - torch::Tensor& cp_local_token_select_indices, - torch::Tensor& kv_cache, - torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, - torch::Tensor& scale); - // Just for unittest void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, const double scale, const std::string& kv_cache_dtype); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index fbb022464ef2..80b4c47c5547 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -396,51 +396,6 @@ __global__ void concat_and_cache_mla_kernel( copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank); } -template -__global__ void cp_fused_concat_and_cache_mla_kernel( - const scalar_t* __restrict__ kv_c, // [num_full_tokens, kv_lora_rank] - const scalar_t* __restrict__ k_pe, // [num_full_tokens, pe_dim] - const int64_t* __restrict__ cp_local_token_select_indices, // [num_tokens] - cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank - // + pe_dim)] - const int64_t* __restrict__ slot_mapping, // [num_tokens] - const int block_stride, // - const int entry_stride, // - const int kv_c_stride, // - const int k_pe_stride, // - const int kv_lora_rank, // - const int pe_dim, // - const int block_size, // - const float* scale // -) { - const int64_t token_idx = cp_local_token_select_indices[blockIdx.x]; - const int64_t slot_idx = slot_mapping[blockIdx.x]; - // NOTE: slot_idx can be -1 if the token is padded - if (slot_idx < 0) { - return; - } - const int64_t block_idx = slot_idx / block_size; - const int64_t block_offset = slot_idx % block_size; - - auto copy = [&](const scalar_t* __restrict__ src, cache_t* __restrict__ dst, - int src_stride, int dst_stride, int size, int offset) { - for (int i = threadIdx.x; i < size; i += blockDim.x) { - const int64_t src_idx = token_idx * src_stride + i; - const int64_t dst_idx = - block_idx * block_stride + block_offset * entry_stride + i + offset; - if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { - dst[dst_idx] = src[src_idx]; - } else { - dst[dst_idx] = - fp8::scaled_convert(src[src_idx], *scale); - } - } - }; - - copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0); - copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank); -} - } // namespace vllm // KV_T is the data type of key and value tensors. 
@@ -554,20 +509,6 @@ void reshape_and_cache_flash( kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \ reinterpret_cast(scale.data_ptr())); -// KV_T is the data type of key and value tensors. -// CACHE_T is the stored data type of kv-cache. -// KV_DTYPE is the real data type of kv-cache. -#define CALL_CP_FUSED_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \ - vllm::cp_fused_concat_and_cache_mla_kernel \ - <<>>( \ - reinterpret_cast(kv_c.data_ptr()), \ - reinterpret_cast(k_pe.data_ptr()), \ - cp_local_token_select_indices.data_ptr(), \ - reinterpret_cast(kv_cache.data_ptr()), \ - slot_mapping.data_ptr(), block_stride, entry_stride, \ - kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \ - reinterpret_cast(scale.data_ptr())); - void concat_and_cache_mla( torch::Tensor& kv_c, // [num_tokens, kv_lora_rank] torch::Tensor& k_pe, // [num_tokens, pe_dim] @@ -606,50 +547,6 @@ void concat_and_cache_mla( CALL_CONCAT_AND_CACHE_MLA); } -// Note(hc): cp_fused_concat_and_cache_mla fuses the following three kernel -// calls into one: -// k_c_normed.index_select(0, cp_local_token_select_indices) + \ -// k_pe.squeeze(1).index_select(0, cp_local_token_select_indices) + \ -// concat_and_cache_mla. -void cp_fused_concat_and_cache_mla( - torch::Tensor& kv_c, // [num_total_tokens, kv_lora_rank] - torch::Tensor& k_pe, // [num_total_tokens, pe_dim] - torch::Tensor& cp_local_token_select_indices, // [num_tokens] - torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank + - // pe_dim)] - torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] - const std::string& kv_cache_dtype, torch::Tensor& scale) { - // NOTE(woosuk): In vLLM V1, key.size(0) can be different from - // slot_mapping.size(0) because of padding for CUDA graphs. - // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because - // both include padding. - // In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0) - // since key includes padding for CUDA graphs, while slot_mapping does not. - // In this case, slot_mapping.size(0) represents the actual number of tokens - // before padding. - // For compatibility with both cases, we use slot_mapping.size(0) as the - // number of tokens. - int num_tokens = slot_mapping.size(0); - int kv_lora_rank = kv_c.size(1); - int pe_dim = k_pe.size(1); - int block_size = kv_cache.size(1); - - TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim); - - int kv_c_stride = kv_c.stride(0); - int k_pe_stride = k_pe.stride(0); - int block_stride = kv_cache.stride(0); - int entry_stride = kv_cache.stride(1); - - dim3 grid(num_tokens); - dim3 block(std::min(kv_lora_rank, 512)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, - CALL_CP_FUSED_CONCAT_AND_CACHE_MLA); -} - namespace vllm { template diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index b769c09adc0f..95fb5b197f53 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -693,16 +693,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor scale) -> ()"); cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla); - cache_ops.def( - "cp_fused_concat_and_cache_mla(Tensor kv_c, Tensor k_pe," - " Tensor cp_local_token_select_indices," - " Tensor! 
kv_cache," - " Tensor slot_mapping," - " str kv_cache_dtype," - " Tensor scale) -> ()"); - cache_ops.impl("cp_fused_concat_and_cache_mla", torch::kCUDA, - &cp_fused_concat_and_cache_mla); - // Convert the key and value cache to fp8 data type. cache_ops.def( "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, " diff --git a/tests/distributed/test_context_parallel.py b/tests/distributed/test_context_parallel.py new file mode 100644 index 000000000000..23be703a3068 --- /dev/null +++ b/tests/distributed/test_context_parallel.py @@ -0,0 +1,263 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +WARNING: This test runs in both single-node (4 GPUs) and multi-node + (2 node with 2 GPUs each) modes. If the test only uses 2 GPUs, it is + important to set the distributed backend to "mp" to avoid Ray scheduling + all workers in a node other than the head node, which can cause the test + to fail. +""" +import json +import os +from dataclasses import dataclass +from typing import Literal, NamedTuple, Optional + +import pytest + +from vllm.config import RunnerOption +from vllm.logger import init_logger + +from ..models.registry import HF_EXAMPLE_MODELS +from ..utils import compare_two_settings, create_new_process_for_each_test + +logger = init_logger("test_context_parallel") + +VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" + + +class ParallelSetup(NamedTuple): + tp_size: int + pp_size: int + dcp_size: int + eager_mode: bool + chunked_prefill: bool + + +class CPTestOptions(NamedTuple): + multi_node_only: bool + load_format: Optional[str] = None + + +@dataclass +class CPTestSettings: + parallel_setups: list[ParallelSetup] + # NOTE: the length of distributed_backends and + # vllm_major_versions should be the same, and they + # are first zipped together to iterate over all + # test settings. 
+ distributed_backends: list[str] + # vllm major version: "0" for V0, "1" for V1 + vllm_major_versions: list[str] + runner: RunnerOption + test_options: CPTestOptions + + def __post_init__(self): + if len(self.distributed_backends) != len(self.vllm_major_versions): + raise ValueError( + f"Length mismatch: distributed_backends " + f"({len(self.distributed_backends)}) != " + f"vllm_major_versions ({len(self.vllm_major_versions)})") + + @staticmethod + def detailed( + *, + tp_base: int = 4, + pp_base: int = 1, + dcp_base: int = 1, + multi_node_only: bool = False, + runner: RunnerOption = "auto", + load_format: Optional[str] = None, + ): + parallel_setups = [] + for eager_mode_val in [False]: + for pp_multiplier in [1]: + for dcp_multiplier in [2, 4]: + for chunked_prefill_val in [True]: + parallel_setups.append( + ParallelSetup(tp_size=tp_base, + pp_size=pp_multiplier * pp_base, + dcp_size=dcp_multiplier * dcp_base, + eager_mode=eager_mode_val, + chunked_prefill=chunked_prefill_val)) + return CPTestSettings( + parallel_setups=parallel_setups, + distributed_backends=["mp"], + vllm_major_versions=["1"], + runner=runner, + test_options=CPTestOptions(multi_node_only=multi_node_only, + load_format=load_format), + ) + + def iter_params(self, model_id: str): + opts = self.test_options + + for parallel_setup in self.parallel_setups: + for backend, vllm_major_version in zip(self.distributed_backends, + self.vllm_major_versions): + yield (model_id, parallel_setup, backend, vllm_major_version, + self.runner, opts) + + +def _compare_cp_with_tp( + model_id: str, + parallel_setup: ParallelSetup, + distributed_backend: str, + vllm_major_version: str, + runner: RunnerOption, + test_options: CPTestOptions, + num_gpus_available: int, + *, + method: Literal["generate"], + is_multimodal: bool, +): + ( + tp_size, + pp_size, + dcp_size, + eager_mode, + chunked_prefill, + ) = parallel_setup + + multi_node_only, load_format = test_options + + model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) + model_info.check_transformers_version(on_fail="skip") + + trust_remote_code = model_info.trust_remote_code + tokenizer_mode = model_info.tokenizer_mode + hf_overrides = model_info.hf_overrides + + if load_format == "dummy": + # Avoid OOM + text_overrides = { + "num_hidden_layers": 4, + "hidden_size": 512, + "intermediate_size": 800, + "num_attention_heads": 4, + "num_key_value_heads": 1, + } + + if is_multimodal: + hf_overrides.update({"text_config": text_overrides}) + else: + hf_overrides.update(text_overrides) + else: + model_info.check_available_online(on_fail="skip") + + if num_gpus_available < tp_size * pp_size: + pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") + if VLLM_MULTI_NODE and distributed_backend == "mp": + pytest.skip("Skipping multi-node pipeline parallel test for " + "multiprocessing distributed backend") + if multi_node_only and not VLLM_MULTI_NODE: + pytest.skip("Not in multi-node setting") + + common_args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--max-num-seqs", + "8", + ] + if chunked_prefill: + common_args.append("--enable-chunked-prefill") + if eager_mode: + common_args.append("--enforce-eager") + if runner != "auto": + common_args.extend(["--runner", runner]) + if trust_remote_code: + common_args.append("--trust-remote-code") + if tokenizer_mode: + common_args.extend(["--tokenizer-mode", tokenizer_mode]) + if load_format: + common_args.extend(["--load-format", load_format]) + if hf_overrides: + 
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) + + cp_env = tp_env = { + "VLLM_USE_V1": + vllm_major_version, # Note(hc): DCP only support V1 engine only + } + + cp_args = [ + *common_args, + "--tensor-parallel-size", + str(tp_size), + "--pipeline-parallel-size", + str(pp_size), + "--decode-context-parallel-size", + str(dcp_size), + "--distributed-executor-backend", + distributed_backend, + ] + + tp_args = [ + *common_args, + "--tensor-parallel-size", + str(tp_size), + "--pipeline-parallel-size", + str(pp_size), + "--distributed-executor-backend", + distributed_backend, + ] + + try: + compare_two_settings(model_id, + cp_args, + tp_args, + cp_env, + tp_env, + method=method, + max_wait_seconds=720) + except Exception: + testing_ray_compiled_graph = cp_env is not None + if testing_ray_compiled_graph and vllm_major_version == "0": + # Ray Compiled Graph tests are flaky for V0, + # so we don't want to fail the test + logger.exception("Ray Compiled Graph tests failed") + else: + raise + + +CP_TEXT_GENERATION_MODELS = { + # [MLA attention only] + "deepseek-ai/DeepSeek-V2-Lite-Chat": CPTestSettings.detailed(), +} + +CP_TEST_MODELS = [ + # TODO support other models + # [LANGUAGE GENERATION] + "deepseek-ai/DeepSeek-V2-Lite-Chat", +] + + +@pytest.mark.parametrize( + ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version", + "runner", "test_options"), + [ + params for model_id, settings in CP_TEXT_GENERATION_MODELS.items() + for params in settings.iter_params(model_id) + if model_id in CP_TEST_MODELS + ], +) +@create_new_process_for_each_test() +def test_cp_generation( + model_id: str, + parallel_setup: ParallelSetup, + distributed_backend: str, + vllm_major_version: str, + runner: RunnerOption, + test_options: CPTestOptions, + num_gpus_available, +): + _compare_cp_with_tp(model_id, + parallel_setup, + distributed_backend, + vllm_major_version, + runner, + test_options, + num_gpus_available, + method="generate", + is_multimodal=False) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index bb67d5790aaa..545f4cb48bf4 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1625,20 +1625,6 @@ def concat_and_cache_mla( scale) -def cp_fused_concat_and_cache_mla( - kv_c: torch.Tensor, - k_pe: torch.Tensor, - cp_local_token_select_indices: torch.Tensor, - kv_cache: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache_dtype: str, - scale: torch.Tensor, -) -> None: - torch.ops._C_cache_ops.cp_fused_concat_and_cache_mla( - kv_c, k_pe, cp_local_token_select_indices, kv_cache, slot_mapping, - kv_cache_dtype, scale) - - def copy_blocks(key_caches: list[torch.Tensor], value_caches: list[torch.Tensor], block_mapping: torch.Tensor) -> None: diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py new file mode 100644 index 000000000000..189b57e8e8b8 --- /dev/null +++ b/vllm/attention/ops/common.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +from vllm.distributed.parallel_state import GroupCoordinator +from vllm.triton_utils import tl, triton + + +@triton.jit +def _correct_attn_cp_out_kernel(outputs_ptr, new_output_ptr, lses_ptr, + vlse_ptr, outputs_stride_B, outputs_stride_H, + outputs_stride_D, lses_stride_N, lses_stride_B, + lses_stride_H, lse_idx, HEAD_DIM: tl.constexpr, + N_ROUNDED: tl.constexpr): + """ + Apply the all-gathered lses to correct each local rank's attention + output. 
we still need perform a cross-rank reduction to obtain the + final attention output. + + Args: + output: [ B, H, D ] + lses : [ N, B, H ] + cp, batch, q_heads, v_head_dim + Return: + output: [ B, H, D ] + lse : [ B, H ] + """ + batch_idx = tl.program_id(axis=0).to(tl.int64) + head_idx = tl.program_id(axis=1).to(tl.int64) + d_offsets = tl.arange(0, HEAD_DIM) + num_n_offsets = tl.arange(0, N_ROUNDED) + + # shape = [N] + lse_offsets = num_n_offsets * lses_stride_N + batch_idx * \ + lses_stride_B + head_idx * lses_stride_H + + # calc final lse + lse = tl.load(lses_ptr + lse_offsets) + lse = tl.where((lse != lse) | (lse == float('inf')), -float('inf'), lse) + lse_max = tl.max(lse, axis=0) + lse -= lse_max + lse_exp = tl.exp(lse) + lse_acc = tl.sum(lse_exp, axis=0) + lse = tl.log(lse_acc) + lse += lse_max + + lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H + tl.store(vlse_ptr + lse_offsets, lse) + + # shape = [D] + output_offsets = batch_idx * outputs_stride_B + \ + head_idx * outputs_stride_H + \ + d_offsets * outputs_stride_D + + # correct output + lse_offset = lse_idx * lses_stride_N + batch_idx * \ + lses_stride_B + head_idx * lses_stride_H + lse_tmp = tl.load(lses_ptr + lse_offset) + lse_finally = lse_tmp - lse + lse_finally = tl.where( + (lse_finally != lse_finally) | (lse_finally == float('inf')), + -float('inf'), lse_finally) + factor = tl.exp(lse_finally) + output = tl.load(outputs_ptr + output_offsets) + output = output * factor + + tl.store(new_output_ptr + output_offsets, output) + + +class CPTritonContext: + """ The CPTritonContext is used to avoid recompilation of the Triton JIT. + """ + + def __init__(self): + self.inner_kernel = None + + def call_kernel(self, kernel, grid, *regular_args, **const_args): + if self.inner_kernel is None: + self.inner_kernel = kernel[grid](*regular_args, **const_args) + else: + self.inner_kernel[grid](*regular_args) + + +def correct_attn_out(out: torch.Tensor, lses: torch.Tensor, cp_rank: int, + ctx: CPTritonContext): + """ + Apply the all-gathered lses to correct each local rank's attention + output. we still need perform a cross-rank reduction to obtain the + final attention output. 
+ + Args: + output: [ B, H, D ] + lses : [ N, B, H ] + Return: + output: [ B, H, D ] + lse : [ B, H ] + """ + if ctx is None: + ctx = CPTritonContext() + + lse = torch.empty_like(lses[0]) + + grid = (out.shape[0], out.shape[1], 1) + regular_args = (out, out, lses, lse, *out.stride(), *lses.stride(), + cp_rank) + const_args = { + "HEAD_DIM": out.shape[-1], + "N_ROUNDED": lses.shape[0], + } + + ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, + **const_args) + return out, lse + + + def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor, + cp_attn_lse: torch.Tensor, + cp_group: GroupCoordinator, + ctx: CPTritonContext = None): + """ + cp_attn_out: [ B, H, D ] + cp_attn_lse: [ B, H ] + """ + if cp_group.world_size == 1: + return cp_attn_out + + if ctx is None: + ctx = CPTritonContext() + + lses = torch.empty((cp_group.world_size, ) + cp_attn_lse.shape, + dtype=cp_attn_lse.dtype, + device=cp_attn_lse.device) + + cp_attn_lse = cp_attn_lse.contiguous() + lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses) + out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx) + assert out.is_contiguous() + out = cp_group.reduce_scatter(out, dim=1) + return out diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py index 564042cf8eb1..2c3e8c42400c 100644 --- a/vllm/attention/ops/flashmla.py +++ b/vllm/attention/ops/flashmla.py @@ -105,7 +105,9 @@ def flash_mla_with_kvcache( descale_q, descale_k, ) - return out, softmax_lse + + # Note(hc): revisit this when we support DCP with decode query_len > 1. + return out.squeeze(1), softmax_lse.squeeze(-1) # diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 9d4594bab3c1..fb8e30996ea3 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -170,6 +170,11 @@ class ParallelConfig: Set to be private as it's not intended to be configured by users. """ + decode_context_parallel_size: int = 1 + """Number of decode context parallel groups. Because the world size does + not change with DCP, it simply reuses the GPUs of the TP group, and + tp_size needs to be divisible by dcp_size.""" + @property def world_size_across_dp(self) -> int: """world_size_across_dp is TPxPPxDP, it is the size of the world diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index fc96c2ac926b..522dfc8d8b5a 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -904,6 +904,18 @@ def get_tensor_model_parallel_group(): return get_tp_group() +_DCP: Optional[GroupCoordinator] = None + + +def get_dcp_group() -> GroupCoordinator: + assert _DCP is not None, ( + "decode context model parallel group is not initialized") + return _DCP + + +# kept for backward compatibility get_context_model_parallel_group = get_dcp_group _PP: Optional[GroupCoordinator] = None _DP: Optional[GroupCoordinator] = None @@ -1034,6 +1046,7 @@ def init_distributed_environment( def initialize_model_parallel( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, + decode_context_model_parallel_size: Optional[int] = 1, backend: Optional[str] = None, ) -> None: """ @@ -1098,6 +1111,23 @@ def initialize_model_parallel( use_message_queue_broadcaster=True, group_name="tp") + # Build the DCP model-parallel groups.
+ global _DCP + assert _DCP is None, ( + "decode context model parallel group is already initialized") + # Note(hc): In the current implementation of decode context parallel, + # dcp_size must not exceed tp_size, because the world size does not + # change by DCP, it simply reuse the GPUs of TP group, and split one + # TP group into tp_size//dcp_size DCP groups. + group_ranks = all_ranks.reshape( + -1, decode_context_model_parallel_size).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] + _DCP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + use_message_queue_broadcaster=True, + group_name="dcp") + # Build the pipeline model-parallel groups. global _PP assert _PP is None, ( @@ -1141,6 +1171,7 @@ def initialize_model_parallel( def ensure_model_parallel_initialized( tensor_model_parallel_size: int, pipeline_model_parallel_size: int, + decode_context_model_parallel_size: Optional[int] = 1, backend: Optional[str] = None, ) -> None: """Helper to initialize model parallel groups if they are not initialized, @@ -1151,7 +1182,8 @@ def ensure_model_parallel_initialized( get_world_group().device_group) if not model_parallel_is_initialized(): initialize_model_parallel(tensor_model_parallel_size, - pipeline_model_parallel_size, backend) + pipeline_model_parallel_size, + decode_context_model_parallel_size, backend) return assert ( @@ -1226,6 +1258,16 @@ def get_tensor_model_parallel_rank(): return get_tp_group().rank_in_group +def get_decode_context_model_parallel_world_size(): + """Return world size for the decode context model parallel group.""" + return get_dcp_group().world_size + + +def get_decode_context_model_parallel_rank(): + """Return my rank for the decode context model parallel group.""" + return get_dcp_group().rank_in_group + + def get_node_count() -> int: """Return the total number of nodes in the distributed environment. """ assert _NODE_COUNT is not None, ( @@ -1246,6 +1288,11 @@ def destroy_model_parallel(): _PP.destroy() _PP = None + global _DCP + if _DCP: + _DCP.destroy() + _DCP = None + global _DP if _DP: _DP.destroy() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 71ee90040f37..d96654ecfa46 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -306,6 +306,8 @@ class EngineArgs: # number of P/D disaggregation (or other disaggregation) workers pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size tensor_parallel_size: int = ParallelConfig.tensor_parallel_size + decode_context_parallel_size: int = \ + ParallelConfig.decode_context_parallel_size data_parallel_size: int = ParallelConfig.data_parallel_size data_parallel_rank: Optional[int] = None data_parallel_start_rank: Optional[int] = None @@ -636,6 +638,9 @@ class EngineArgs: **parallel_kwargs["pipeline_parallel_size"]) parallel_group.add_argument("--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"]) + parallel_group.add_argument( + "--decode-context-parallel-size", "-dcp", + **parallel_kwargs["decode_context_parallel_size"]) parallel_group.add_argument("--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]) parallel_group.add_argument( @@ -1156,6 +1161,17 @@ class EngineArgs: # global layers in interleaved sliding window models. 
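As an aside on the DCP group construction in the parallel_state.py hunk above: below is a minimal standalone sketch (plain PyTorch, hypothetical single-node sizes, not the vLLM `init_model_parallel_group` helpers) of how the flat rank list is carved into DCP subgroups with the same reshape/unbind pattern.

```python
import torch

# Hypothetical single-node setup: world_size = tp_size = 4, dcp_size = 2.
# Mirrors the `all_ranks.reshape(-1, dcp_size).unbind(0)` pattern above:
# consecutive TP ranks are grouped, so one TP group of 4 GPUs is split
# into tp_size // dcp_size = 2 DCP groups.
world_size, dcp_size = 4, 2
all_ranks = torch.arange(world_size)

dcp_groups = [ranks.tolist() for ranks in all_ranks.reshape(-1, dcp_size).unbind(0)]
print(dcp_groups)  # [[0, 1], [2, 3]]
```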
sliding_window = model_config.get_sliding_window() + # Note(hc): In the current implementation of decode context + # parallel (DCP), tp_size needs to be divisible by dcp_size, + # because the world size does not change with DCP; it simply + # reuses the GPUs of the TP group and splits one TP group into + # tp_size//dcp_size DCP groups. + assert self.tensor_parallel_size % self.decode_context_parallel_size \ + == 0, ( + f"tp_size={self.tensor_parallel_size} must be divisible by " + f"dcp_size={self.decode_context_parallel_size}." + ) + cache_config = CacheConfig( block_size=self.block_size, gpu_memory_utilization=self.gpu_memory_utilization, @@ -1306,6 +1322,7 @@ class EngineArgs: distributed_executor_backend=self.distributed_executor_backend, worker_cls=self.worker_cls, worker_extension_cls=self.worker_extension_cls, + decode_context_parallel_size=self.decode_context_parallel_size, ) speculative_config = self.create_speculative_config( diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 9696b6c0913c..090ebf93840d 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -201,10 +201,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, AttentionMetadata, MLAAttentionImpl) from vllm.attention.backends.utils import get_mla_dims +from vllm.attention.ops.common import cp_lse_ag_out_rs from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.config import VllmConfig -from vllm.distributed.parallel_state import is_global_first_rank +from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank from vllm.logger import init_logger from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearBase, @@ -323,6 +324,13 @@ class MLACommonPrefillMetadata: seq_lens: torch.Tensor workspace: torch.Tensor + # for mla DCP + cp_chunk_seq_lens: Optional[list[list[int]]] = None + origin_context_lens: Optional[list[int]] = None + cp_cu_seq_lens: Optional[torch.Tensor] = None + chunk_size: Optional[int] = None + cu_seq_lens_lst: Optional[list[list[int]]] = None + block_table: torch.Tensor query_start_loc: torch.Tensor max_query_len: int @@ -444,6 +452,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): parallel_config) self.mla_dims = get_mla_dims(self.model_config) self.aot_schedule = current_platform.is_cuda() + try: + self.dcp_world_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group + except AssertionError: + # DCP might not be initialized in testing + self.dcp_world_size = 1 + self.dcp_rank = 0 # Dont try to access the runner on AMD if self.aot_schedule: @@ -465,12 +480,27 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): 128 * 1024) assert self.chunked_prefill_workspace_size >= \ scheduler_config.max_num_seqs * cache_config.block_size - self.chunked_prefill_workspace = torch.empty( - (self.chunked_prefill_workspace_size, - self.model_config.get_head_size()), - dtype=self.model_config.dtype, - device=device, - ) + if self.dcp_world_size > 1: + # Note(hc): The local kvcache is incomplete when DCP is enabled, so + # an additional kvcache allgather across the DCP group is required, + # and the workspace has to be enlarged by 1/dcp_world_size relative + # to the original TP allocation.
+ assert self.chunked_prefill_workspace_size % \ + self.dcp_world_size == 0 + self.chunked_prefill_workspace = torch.empty( + (self.chunked_prefill_workspace_size + + self.chunked_prefill_workspace_size // self.dcp_world_size, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=device, + ) + else: + self.chunked_prefill_workspace = torch.empty( + (self.chunked_prefill_workspace_size, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=device, + ) self._use_cudnn_prefill = use_cudnn_prefill() self._use_fi_prefill = use_flashinfer_prefill() @@ -631,6 +661,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.reorder_batch_threshold) + # Note(hc): update seq_lens of decode reqs under DCP. + if self.dcp_world_size > 1: + seq_lens[:num_decodes] = seq_lens[:num_decodes] \ + // self.dcp_world_size + (self.dcp_rank <= \ + (seq_lens[:num_decodes] - 1) % self.dcp_world_size) + assert num_decodes + num_prefills == num_reqs assert num_decode_tokens + num_prefill_tokens == num_tokens @@ -639,6 +675,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): reqs_start = num_decodes # prefill_start context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] + # Note(hc): The context lengths in the perspective of dcp rank0. + cp_context_lens_cpu = torch.ceil(context_lens_cpu.float() / + self.dcp_world_size).int() + origin_context_lens = context_lens_cpu.tolist() max_context_len_cpu = context_lens_cpu.max().item() num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() prefill_query_start_loc = query_start_loc[ @@ -691,20 +731,66 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32) + if self.dcp_world_size > 1: + # Note(hc): The above max_context_chunk already enforces + # block_size alignment, DCP just need the block_size can + # be divisible by dcp_world_size, because DCP use + # cp_gather_cache which not require `cp_chunk_starts` + # aligned to page_size. 
+ assert max_context_chunk % self.dcp_world_size == 0 + cp_max_context_chunk = max_context_chunk // \ + self.dcp_world_size + cp_chunk_starts = \ + torch.arange(num_chunks, dtype=torch.int32) \ + .unsqueeze(1).expand(-1, num_prefills) \ + * cp_max_context_chunk + cp_chunk_ends = torch.min( + cp_context_lens_cpu.unsqueeze(0), + cp_chunk_starts + cp_max_context_chunk) + cp_chunk_seq_lens = (cp_chunk_ends - + cp_chunk_starts).clamp(min=0) + + cp_cu_seq_lens_cpu = torch.zeros(num_chunks, + num_prefills + 1, + dtype=torch.int32, + pin_memory=True) + torch.cumsum(cp_chunk_seq_lens, + dim=1, + out=cp_cu_seq_lens_cpu[:, 1:], + dtype=torch.int32) + chunked_context_metadata_cls = \ CudnnPrefillMetadata.ChunkedContextMetadata \ if self._use_cudnn_prefill else \ MLACommonPrefillMetadata.ChunkedContextMetadata - - chunked_context_metadata = \ - chunked_context_metadata_cls( - cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), - starts=chunk_starts.to(device, non_blocking=True), - seq_tot=chunk_seq_lens.sum(dim=1).tolist(), - max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), - seq_lens=chunk_seq_lens, - workspace=self.chunked_prefill_workspace, - ) + if self.dcp_world_size > 1: + chunked_context_metadata = \ + chunked_context_metadata_cls( + cu_seq_lens=cu_seq_lens_cpu \ + .to(device, non_blocking=True), + starts=cp_chunk_starts.to(device, non_blocking=True), + seq_tot=cp_chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + seq_lens=chunk_seq_lens, + workspace=self.chunked_prefill_workspace, + cp_chunk_seq_lens=cp_chunk_seq_lens.tolist(), + origin_context_lens=origin_context_lens, + cp_cu_seq_lens=cp_cu_seq_lens_cpu \ + .to(device, non_blocking=True), + chunk_size=max_context_chunk, + cu_seq_lens_lst=cu_seq_lens_cpu.tolist(), + ) + else: + chunked_context_metadata = \ + chunked_context_metadata_cls( + cu_seq_lens=cu_seq_lens_cpu \ + .to(device, non_blocking=True), + starts=chunk_starts.to(device, non_blocking=True), + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + seq_lens=chunk_seq_lens, + workspace=self.chunked_prefill_workspace, + ) if self._use_cudnn_prefill: chunked_context_metadata.seq_lens = chunk_seq_lens @@ -757,6 +843,71 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): return attn_metadata +def reorg_kvcache( + allgatered_kv_c_normed: torch.Tensor, + allgatered_k_pe: torch.Tensor, + cp_chunk_seq_lens_lst: list[int], + origin_context_lens: list[int], + cp_world_size: int, + sum_seq_len: int, + max_seq_len: int, + chunk_size: int, + chunk_idx: int, + toks: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + reorg kvcache after cp local gather to tp layout for attn kernel. + + Args: + cp_chunk_seq_lens_lst: chunk context lengths under CP. + origin_context_lens: origin full context lengths under CP. + cp_world_size: CP size. + sum_seq_len: the sum of cp_chunk_seq_lens_lst. + max_seq_len: the max value of cp_chunk_seq_lens_lst. + chunk_size: equals to max_context_chunk from + chunked_context_metadata building. + chunk_idx: chunk idx of chunked_prefill. + toks: the number of tokens for local gather cache. 
+ """ + kv_c_segments = [] + k_pe_segments = [] + src_token_idx = 0 + max_seq_len_check = 0 + for cp_chunk_seq_len, origin_context_len in zip(cp_chunk_seq_lens_lst, + origin_context_lens): + chunk_context_len = chunk_size + if cp_chunk_seq_len != 0: + chunk_context_len = min( + chunk_context_len, origin_context_len - chunk_size * chunk_idx) + cp_target_rank = (chunk_context_len - 1) % cp_world_size + cur_seq_len = 0 + for rank in range(cp_world_size): + if rank > cp_target_rank and cp_chunk_seq_len: + real_cp_chunk_seq_len = cp_chunk_seq_len - 1 + else: + real_cp_chunk_seq_len = cp_chunk_seq_len + if real_cp_chunk_seq_len: + kv_c_segment = allgatered_kv_c_normed[rank * toks + + src_token_idx:rank * + toks + src_token_idx + + real_cp_chunk_seq_len] + k_pe_segment = allgatered_k_pe[rank * toks + + src_token_idx:rank * toks + + src_token_idx + + real_cp_chunk_seq_len] + kv_c_segments.append(kv_c_segment) + k_pe_segments.append(k_pe_segment) + cur_seq_len += real_cp_chunk_seq_len + max_seq_len_check = max(max_seq_len_check, cur_seq_len) + src_token_idx += cp_chunk_seq_len + reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0) + reorganized_k_pe = torch.cat(k_pe_segments, dim=0) + assert reorganized_kv_c_normed.shape[0] == sum_seq_len + assert reorganized_k_pe.shape[0] == sum_seq_len + assert max_seq_len_check == max_seq_len + return reorganized_kv_c_normed, reorganized_k_pe + + class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): """ NOTE: Please read the comment at the top of the file before trying to @@ -836,6 +987,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): self.vllm_flash_attn_version == 3 and current_platform.get_device_capability()[0] == 9) + self.dcp_world_size: Optional[int] = None + def _flash_attn_varlen_diff_headdims(self, q, k, @@ -1152,6 +1305,108 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): return output, output_lse + def _context_parallel_compute_prefill_context( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + k_scale: torch.Tensor, + dcp_world_size: int, + ): + assert k_scale is None, "DCP not support sacled kvcache now." 
+ assert attn_metadata.prefill is not None + prefill_metadata = attn_metadata.prefill + assert prefill_metadata.chunked_context is not None + assert prefill_metadata.chunked_context.cp_chunk_seq_lens is not None + assert prefill_metadata.chunked_context.origin_context_lens is not None + assert prefill_metadata.chunked_context.cp_cu_seq_lens is not None + assert prefill_metadata.chunked_context.chunk_size is not None + assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None + + output = None + iters = len(prefill_metadata.chunked_context.seq_tot) + workspace = prefill_metadata.chunked_context.workspace + + for i in range(iters): + toks = prefill_metadata.chunked_context.seq_tot[i] + ops.cp_gather_cache( + src_cache=kv_c_and_k_pe_cache, + dst=workspace, + block_table=prefill_metadata.block_table, + cu_seq_lens=prefill_metadata.chunked_context.cp_cu_seq_lens[i], + batch_size=attn_metadata.num_prefills, + seq_starts=prefill_metadata.chunked_context.starts[i], + ) + # workspace + # |------- N tokens --------|--------- N*dcp_size tokens ----------| + # |<- use for loca_gather ->|<--------- use for allgather -------->| + allgather_offset = workspace.shape[0] // (dcp_world_size + 1) + assert allgather_offset * (dcp_world_size + + 1) == workspace.shape[0] + assert toks <= allgather_offset + local_gathered_kvcache = workspace[:toks] + cur_allgather_workspace = workspace[ + allgather_offset:allgather_offset * (1 + dcp_world_size)] + assert toks * dcp_world_size <= cur_allgather_workspace.shape[0] + cur_allgather_kvcache = cur_allgather_workspace[:toks * + dcp_world_size] + cur_allgather_kvcache.copy_(get_dcp_group().all_gather( + local_gathered_kvcache, dim=0)) + assert cur_allgather_kvcache.shape[ + -1] == self.kv_lora_rank + self.qk_rope_head_dim + allgatered_kv_c_normed, allgatered_k_pe = \ + cur_allgather_kvcache.unsqueeze( + 1).split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + kv_c_normed, k_pe = reorg_kvcache( + allgatered_kv_c_normed, + allgatered_k_pe, + cp_chunk_seq_lens_lst=prefill_metadata.chunked_context. + cp_chunk_seq_lens[i], + origin_context_lens=prefill_metadata.chunked_context. 
+ origin_context_lens, + cp_world_size=dcp_world_size, + sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i] + [-1], + max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i], + chunk_size=prefill_metadata.chunked_context.chunk_size, + chunk_idx=i, + toks=toks) + + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), + dim=-1) + + attn_output, attn_softmax_lse = self._run_prefill_context_chunk( + prefill=prefill_metadata, + chunk_idx=i, + q=q, + k=k, + v=v, + ) + + if output is None: + output = attn_output + output_lse = attn_softmax_lse + else: + output_tmp = torch.empty_like(output) + output_lse_tmp = torch.empty_like(output_lse) + merge_attn_states( + output=output_tmp, + output_lse=output_lse_tmp, + prefix_output=output, + prefix_lse=output_lse, + suffix_output=attn_output, + suffix_lse=attn_softmax_lse, + ) + output = output_tmp + output_lse = output_lse_tmp + + return output, output_lse + def _forward_prefill( self, q: torch.Tensor, @@ -1162,6 +1417,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): k_scale: torch.Tensor, ) -> torch.Tensor: assert attn_metadata.prefill is not None + assert self.dcp_world_size is not None has_context = attn_metadata.prefill.chunked_context is not None kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ @@ -1181,8 +1437,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): if has_context: suffix_output, suffix_lse = output - context_output, context_lse = self._compute_prefill_context( \ - q, kv_c_and_k_pe_cache, attn_metadata, k_scale) + if self.dcp_world_size > 1: + context_output, context_lse = \ + self._context_parallel_compute_prefill_context( + q, kv_c_and_k_pe_cache, attn_metadata, + k_scale=None, dcp_world_size=self.dcp_world_size) + else: + context_output, context_lse = \ + self._compute_prefill_context( + q, kv_c_and_k_pe_cache, attn_metadata, k_scale) output = torch.empty_like(suffix_output) merge_attn_states( @@ -1202,12 +1465,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): @abstractmethod def _forward_decode( self, - ql_nope: torch.Tensor, - q_pe: torch.Tensor, + q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: M, layer: AttentionLayer, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: raise NotImplementedError def forward( @@ -1235,6 +1497,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): # same expert outputs. return output.fill_(0) + if self.dcp_world_size is None: + self.dcp_world_size = get_dcp_group().world_size + fp8_attention = self.kv_cache_dtype.startswith("fp8") num_actual_toks = attn_metadata.num_actual_tokens @@ -1313,7 +1578,26 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): layer._q_scale) decode_q_pe = decode_q_pe.reshape(q_pe_shape) - output[:num_decode_tokens] = self._forward_decode( - decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer) + decode_q = (decode_ql_nope, decode_q_pe) + if self.dcp_world_size > 1: + assert not fp8_attention, "DCP not support fp8 kvcache now." + # concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P) + decode_q = torch.cat(decode_q, dim=-1) + # decode_q do allgather in head dim. 
+ decode_q = get_dcp_group().all_gather(decode_q, dim=1) + # call decode attn + attn_out, lse = self._forward_decode(decode_q, kv_cache, + attn_metadata, layer) + + # correct the dcp attn_out with the all-gathered lse. + if self.dcp_world_size > 1: + assert lse is not None, ( + "For an MLA backend to enable " + "DCP, it is mandatory that the corresponding decode attn " + "kernel returns the softmax lse.") + attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group()) + + # v_up projection + output[:num_decode_tokens] = self._v_up_proj(attn_out) return output_padded diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index 705307d4dea3..95dce8d8e2ee 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -232,7 +232,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): self._workspace.get_buf(), self.scale, self._num_kv_splits) - return self._v_up_proj(o) + return o # TODO: Currently we leave it here only for backup in case something is # wrong with the new SM100 CUTLASS MLA kernel @@ -265,21 +265,25 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): attn_metadata.decode.seq_lens, attn_metadata.decode.block_table, self.scale) - return self._v_up_proj(o) + return o def _forward_decode( self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, + q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, layer: AttentionLayer, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + if type(q) is tuple: + q_nope, q_pe = q + else: + q_nope, q_pe = torch.split( + q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) if self._use_old_cutlass_mla: # TODO: Remove the old cutlass MLA kernel after more extensive # testing return self._old_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache, - attn_metadata) + attn_metadata), None return self._sm100_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache, - attn_metadata) + attn_metadata), None diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 0e08307ddf84..e2a63c2f577e 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import ClassVar, Optional, Union import torch @@ -154,15 +154,20 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]): def _forward_decode( self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, + q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: FlashAttnMLAMetadata, layer: AttentionLayer, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None + if type(q) is tuple: + q_nope, q_pe = q + else: + q_nope, q_pe = torch.split( + q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + if self.kv_cache_dtype.startswith("fp8"): raise NotImplementedError( "FP8 FlashAttention MLA not yet supported") diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index df617ab7a8ea..11c91b8a0650 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass
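Stepping back from the per-backend signature changes for a moment: the decode path above all-gathers the query heads, lets each DCP rank attend over its local slice of the KV cache, and then combines the partial outputs with cp_lse_ag_out_rs. Below is a small self-contained PyTorch check (toy shapes, single query and head, no vLLM imports) of the log-sum-exp merge identity that this correction relies on.

```python
import torch

torch.manual_seed(0)
D, N, W = 8, 12, 4                       # head dim, num keys, "ranks" (toy sizes)
q = torch.randn(D)
k, v = torch.randn(N, D), torch.randn(N, D)
scale = D ** -0.5

def partial_attn(keys, values):
    s = (keys @ q) * scale               # attention scores for this key subset
    return torch.softmax(s, dim=0) @ values, torch.logsumexp(s, dim=0)

ref, _ = partial_attn(k, v)              # attention over all keys at once

# Per-"rank" partial outputs/LSEs over a round-robin split of the keys.
outs, lses = zip(*(partial_attn(k[r::W], v[r::W]) for r in range(W)))
outs, lses = torch.stack(outs), torch.stack(lses)

# Merge: weight each partial output by exp(lse_r - global_lse) and sum,
# which mirrors what correct_attn_out plus the cross-rank reduction compute.
global_lse = torch.logsumexp(lses, dim=0)
merged = (torch.exp(lses - global_lse).unsqueeze(-1) * outs).sum(dim=0)
assert torch.allclose(merged, ref, atol=1e-6)
```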
-from typing import ClassVar, Optional +from typing import ClassVar, Optional, Union import torch @@ -169,20 +169,20 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): def _forward_decode( self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, + q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: FlashMLAMetadata, layer: AttentionLayer, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None - q = torch.cat([q_nope, q_pe], dim=-1)\ - .unsqueeze(1) # Add seqlen dim of 1 (decode) + if type(q) is tuple: + q = torch.cat(q, dim=-1) - o, _ = flash_mla_with_kvcache( - q=q, + assert isinstance(q, torch.Tensor) + o, lse = flash_mla_with_kvcache( + q=q.unsqueeze(1), # Add seqlen dim of 1 (decode) k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 block_table=attn_metadata.decode.block_table, cache_seqlens=attn_metadata.decode.seq_lens, @@ -196,4 +196,4 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): descale_k=layer._k_scale.reshape(1), ) - return self._v_up_proj(o) + return o, lse diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 42670093daa9..fc6b1998e8eb 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import ClassVar, Optional, Union import torch @@ -220,18 +220,19 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): def _forward_decode( self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, + q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: AiterMLAMetadata, layer: AttentionLayer, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None - B = q_nope.shape[0] + if type(q) is tuple: + q = torch.cat(q, dim=-1) - q = torch.cat([q_nope, q_pe], dim=-1) + assert isinstance(q, torch.Tensor) + B = q.shape[0] o = torch.zeros(B, self.num_heads, self.kv_lora_rank, @@ -249,4 +250,4 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): attn_metadata.decode.paged_kv_indices, attn_metadata.decode.paged_kv_last_page_len) - return self._v_up_proj(o) + return o, None diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index f2974ed668d9..d692b00d78b4 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Optional, Union import torch @@ -123,21 +123,22 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): def _forward_decode( self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, + q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, layer: AttentionLayer, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None if self.kv_cache_dtype.startswith("fp8"): raise NotImplementedError("FP8 Triton MLA not yet supported") - B = q_nope.shape[0] + if 
type(q) is tuple: + q = torch.cat(q, dim=-1) - q = torch.cat([q_nope, q_pe], dim=-1) + assert isinstance(q, torch.Tensor) + B = q.shape[0] o = torch.zeros(B, self.num_heads, self.kv_lora_rank, @@ -171,4 +172,4 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): attn_metadata.decode.seq_lens, attn_logits, num_kv_splits, self.scale, PAGE_SIZE) - return self._v_up_proj(o) + return o, None diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 9421341f990c..86771060c409 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -24,6 +24,7 @@ class KVCacheCoordinator(ABC): use_eagle: bool, enable_caching: bool, enable_kv_cache_events: bool, + dcp_world_size: int, ): self.kv_cache_config = kv_cache_config self.max_model_len = max_model_len @@ -39,6 +40,7 @@ class KVCacheCoordinator(ABC): kv_cache_spec=kv_cache_group.kv_cache_spec, block_pool=self.block_pool, kv_cache_group_id=i, + dcp_world_size=dcp_world_size, ) for i, kv_cache_group in enumerate( self.kv_cache_config.kv_cache_groups)) @@ -197,9 +199,14 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator): """ def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, - use_eagle: bool, enable_kv_cache_events: bool): - super().__init__(kv_cache_config, max_model_len, use_eagle, False, - enable_kv_cache_events) + use_eagle: bool, enable_kv_cache_events: bool, + dcp_world_size: int): + super().__init__(kv_cache_config, + max_model_len, + use_eagle, + False, + enable_kv_cache_events, + dcp_world_size=dcp_world_size) self.num_single_type_manager = len(self.single_type_managers) def get_num_common_prefix_blocks(self, request_id: str, @@ -225,12 +232,19 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, enable_caching: bool, - enable_kv_cache_events: bool): - super().__init__(kv_cache_config, max_model_len, use_eagle, - enable_caching, enable_kv_cache_events) + enable_kv_cache_events: bool, dcp_world_size: int): + super().__init__(kv_cache_config, + max_model_len, + use_eagle, + enable_caching, + enable_kv_cache_events, + dcp_world_size=dcp_world_size) self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[ 0].kv_cache_spec self.block_size = self.kv_cache_spec.block_size + self.dcp_world_size = dcp_world_size + if dcp_world_size > 1: + self.block_size *= dcp_world_size assert len(self.kv_cache_config.kv_cache_groups) == 1, ( "UnitaryKVCacheCoordinator assumes only one kv cache group") @@ -246,6 +260,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): block_pool=self.block_pool, kv_cache_spec=self.kv_cache_spec, use_eagle=self.use_eagle, + dcp_world_size=self.dcp_world_size, ) return hit_blocks, len(hit_blocks[0]) * self.block_size @@ -261,9 +276,14 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, enable_caching: bool, - enable_kv_cache_events: bool): - super().__init__(kv_cache_config, max_model_len, use_eagle, - enable_caching, enable_kv_cache_events) + enable_kv_cache_events: bool, dcp_world_size: int): + super().__init__(kv_cache_config, + max_model_len, + use_eagle, + enable_caching, + enable_kv_cache_events, + dcp_world_size=dcp_world_size) + assert dcp_world_size == 1, "DCP not support hybrid attn now." 
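A quick numeric aside on the block-size scaling introduced in the KV cache coordinator above (hypothetical numbers, not the vLLM API): under DCP, prefix-cache bookkeeping operates on "logical" blocks that span dcp_world_size physical blocks, one per rank.

```python
# Toy numbers: 16-token physical blocks per rank, dcp_world_size = 4.
physical_block_size = 16
dcp_world_size = 4
logical_block_size = physical_block_size * dcp_world_size  # 64 tokens hashed per block

prompt_len = 200
full_blocks = prompt_len // logical_block_size       # fully-filled logical blocks
cacheable_tokens = full_blocks * logical_block_size  # tokens eligible for prefix hits
print(full_blocks, cacheable_tokens)                 # 3 192
```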
self.verify_and_split_kv_cache_groups() def verify_and_split_kv_cache_groups(self) -> None: @@ -394,17 +414,27 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): return hit_blocks, hit_length -def get_kv_cache_coordinator( - kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, - enable_caching: bool, - enable_kv_cache_events: bool) -> KVCacheCoordinator: +def get_kv_cache_coordinator(kv_cache_config: KVCacheConfig, + max_model_len: int, use_eagle: bool, + enable_caching: bool, + enable_kv_cache_events: bool, + dcp_world_size: int) -> KVCacheCoordinator: if not enable_caching: - return KVCacheCoordinatorNoPrefixCache(kv_cache_config, max_model_len, + return KVCacheCoordinatorNoPrefixCache(kv_cache_config, + max_model_len, use_eagle, - enable_kv_cache_events) + enable_kv_cache_events, + dcp_world_size=dcp_world_size) if len(kv_cache_config.kv_cache_groups) == 1: - return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len, - use_eagle, enable_caching, - enable_kv_cache_events) - return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle, - enable_caching, enable_kv_cache_events) + return UnitaryKVCacheCoordinator(kv_cache_config, + max_model_len, + use_eagle, + enable_caching, + enable_kv_cache_events, + dcp_world_size=dcp_world_size) + return HybridKVCacheCoordinator(kv_cache_config, + max_model_len, + use_eagle, + enable_caching, + enable_kv_cache_events, + dcp_world_size=dcp_world_size) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 87a11fe58a04..3a0fbb5e5c41 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -91,6 +91,7 @@ class KVCacheManager: use_eagle: bool = False, log_stats: bool = False, enable_kv_cache_events: bool = False, + dcp_world_size: int = 1, ) -> None: self.max_model_len = max_model_len @@ -109,12 +110,20 @@ class KVCacheManager: self.block_size = kv_cache_config.kv_cache_groups[ 0].kv_cache_spec.block_size + if dcp_world_size > 1: + assert len(kv_cache_config.kv_cache_groups) == 1 + # Note(hc): need revisit. When both DCP and any future + # PCP are enabled, the block_size may need to be scaled + # by a factor of dcp_size × pcp_size? 
+ self.block_size *= dcp_world_size + self.coordinator = get_kv_cache_coordinator( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, use_eagle=self.use_eagle, enable_caching=self.enable_caching, enable_kv_cache_events=enable_kv_cache_events, + dcp_world_size=dcp_world_size, ) self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) self.block_pool = self.coordinator.block_pool diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 248ad9cda7c2..aff1183e499a 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -846,6 +846,12 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, ) num_tokens = num_blocks * vllm_config.cache_config.block_size + if vllm_config.parallel_config.decode_context_parallel_size > 1: + num_tokens *= vllm_config.parallel_config.decode_context_parallel_size + logger.info( + "Multiplying the GPU KV cache size by the dcp_world_size %d.", + vllm_config.parallel_config.decode_context_parallel_size) + num_tokens_str = f"{num_tokens:,}" logger.info("GPU KV cache size: %s tokens", num_tokens_str) max_model_len_str = f"{vllm_config.model_config.max_model_len:,}" diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 8322fa7335b6..31f7e9c70f8b 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -100,6 +100,15 @@ class Scheduler(SchedulerInterface): self.block_size = self.cache_config.block_size + self.dcp_world_size = \ + vllm_config.parallel_config.decode_context_parallel_size + # Note(hc): The scheduler’s block_size must be multiplied + # by dcp_world_size, since block hashes are computed on the + # original full token sequence at a granularity of + # original_block_size × dcp_world_size. + if self.dcp_world_size > 1: + self.block_size *= self.dcp_world_size + # req_id -> Request self.requests: dict[str, Request] = {} # Scheduling policy @@ -161,6 +170,7 @@ class Scheduler(SchedulerInterface): use_eagle=self.use_eagle, log_stats=self.log_stats, enable_kv_cache_events=self.enable_kv_cache_events, + dcp_world_size=self.dcp_world_size, ) self.use_pp = self.parallel_config.pipeline_parallel_size > 1 diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index f6affb3dab66..8159349e4675 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -25,6 +25,7 @@ class SingleTypeKVCacheManager(ABC): kv_cache_spec: KVCacheSpec, block_pool: BlockPool, kv_cache_group_id: int, + dcp_world_size: int = 1, ) -> None: """ Initializes the SingleTypeKVCacheManager. @@ -33,8 +34,10 @@ class SingleTypeKVCacheManager(ABC): block_pool: The block pool. kv_cache_group_id: The id of the kv cache group of this manager. 
""" - self.block_size = kv_cache_spec.block_size + self.dcp_world_size = dcp_world_size + if self.dcp_world_size > 1: + self.block_size *= dcp_world_size self.kv_cache_spec = kv_cache_spec self.block_pool = block_pool @@ -196,6 +199,7 @@ class SingleTypeKVCacheManager(ABC): block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, + dcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: """ Get the longest cache hit prefix of the blocks that is not longer than @@ -253,6 +257,7 @@ class FullAttentionManager(SingleTypeKVCacheManager): block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, + dcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance( kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec) @@ -260,7 +265,10 @@ class FullAttentionManager(SingleTypeKVCacheManager): "and chunked local attention groups" computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( [] for _ in range(len(kv_cache_group_ids))) - max_num_blocks = max_length // kv_cache_spec.block_size + block_size = kv_cache_spec.block_size + if dcp_world_size > 1: + block_size *= dcp_world_size + max_num_blocks = max_length // block_size for block_hash in itertools.islice(block_hashes, max_num_blocks): # block_hashes is a chain of block hashes. If a block hash is not # in the cached_block_hash_to_id, the following block hashes are @@ -310,9 +318,11 @@ class SlidingWindowManager(SingleTypeKVCacheManager): block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, + dcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance(kv_cache_spec, SlidingWindowSpec), ( "SlidingWindowManager can only be used for sliding window groups") + assert dcp_world_size == 1, "DCP not support sliding window attn now." # The number of contiguous blocks needed for prefix cache hit. # -1 since the input token itself is also included in the window @@ -408,6 +418,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, + dcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: """ For chunked local attention, we need to find the longest cache hit @@ -445,6 +456,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): "chunked local attention groups") assert use_eagle is False, ("Hybrid KV cache is not supported for " + "eagle + chunked local attention.") + assert dcp_world_size == 1, "DCP not support chunked local attn now." max_num_blocks = max_length // kv_cache_spec.block_size if max_length > 0: local_attention_start_idx = (max_length // @@ -525,10 +537,12 @@ class MambaManager(SingleTypeKVCacheManager): block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, + dcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance( kv_cache_spec, MambaSpec), ("MambaManager can only be used for mamba groups") + assert dcp_world_size == 1, "DCP not support mamba now." # Prefix caching is not supported for mamba now. Always return empty # list. computed_blocks: tuple[list[KVCacheBlock], ...] 
@@ -583,6 +597,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
         block_pool: BlockPool,
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
+        dcp_world_size: int = 1,
     ) -> tuple[list[KVCacheBlock], ...]:
         assert isinstance(kv_cache_spec, CrossAttentionSpec), (
             "CrossAttentionManager can only be used for cross-attention groups"
diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py
index a3e4d393e4d2..6467fcfe40ae 100644
--- a/vllm/v1/kv_cache_interface.py
+++ b/vllm/v1/kv_cache_interface.py
@@ -86,6 +86,12 @@ class FullAttentionSpec(AttentionSpec):
 
     def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
         max_model_len = vllm_config.model_config.max_model_len
+        dcp_world_size = \
+            vllm_config.parallel_config.decode_context_parallel_size
+        # Note(hc): each dcp rank only needs to save
+        # (max_model_len // dcp_world_size) tokens locally.
+        if dcp_world_size > 1:
+            max_model_len = cdiv(max_model_len, dcp_world_size)
         return cdiv(max_model_len, self.block_size) * self.page_size_bytes
 
     @classmethod
@@ -162,6 +168,8 @@ class SlidingWindowSpec(AttentionSpec):
         assert not self.use_mla, "MLA is not supported for sliding window"
 
     def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
+        assert vllm_config.parallel_config.decode_context_parallel_size == 1, \
+            "DCP does not support sliding window."
         max_model_len = vllm_config.model_config.max_model_len
         max_num_batched_tokens = (
             vllm_config.scheduler_config.max_num_batched_tokens)
diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py
index 6ab5ce2748a4..c5902595a496 100644
--- a/vllm/v1/worker/block_table.py
+++ b/vllm/v1/worker/block_table.py
@@ -4,6 +4,7 @@
 import numpy as np
 import torch
 
+from vllm.distributed import get_dcp_group
 from vllm.logger import init_logger
 from vllm.utils import cdiv
 
@@ -50,6 +51,13 @@ class BlockTable:
         self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
                                         dtype=torch.int64,
                                         device=self.device)
+        try:
+            self.dcp_world_size = get_dcp_group().world_size
+            self.dcp_rank = get_dcp_group().rank_in_group
+        except AssertionError:
+            # DCP might not be initialized in testing
+            self.dcp_world_size = 1
+            self.dcp_rank = 0
 
     def append_row(
         self,
@@ -89,13 +97,36 @@ class BlockTable:
         # NOTE(woosuk): We can't simply use `token_indices // block_size`
         # here because M (max_model_len) is not necessarily divisible by
         # block_size.
-        block_table_indices = (req_indices * self.max_num_blocks_per_req +
-                               positions // self.block_size)
-        block_numbers = self.block_table_np.ravel()[block_table_indices]
-        block_offsets = positions % self.block_size
-        np.add(block_numbers * self.block_size,
-               block_offsets,
-               out=self.slot_mapping_np[:req_indices.shape[0]])
+        if self.dcp_world_size > 1:
+            # Note(hc): DCP stores the kv cache in an interleaved style: the
+            # kv cache for the token whose token_idx is i is always stored
+            # on the GPU whose dcp_rank equals i % dcp_world_size.
+
+            # Use a "virtual block" whose size equals dcp_world_size *
+            # block_size for the block_table_indices calculation.
+            virtual_block_size = self.block_size * self.dcp_world_size
+            block_table_indices = (req_indices * self.max_num_blocks_per_req +
+                                   positions // virtual_block_size)
+            block_numbers = self.block_table_np.ravel()[block_table_indices]
+            # Use virtual_block_size for the mask calculation, which marks
+            # local tokens.
+            virtual_block_offsets = positions % virtual_block_size
+            mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank
+            # Calculate local block_offsets
+            block_offsets = virtual_block_offsets // self.dcp_world_size
+            # Calculate slot_mapping
+            slot_mapping = block_numbers * self.block_size + block_offsets
+            # Write final slots; use -1 for non-local tokens
+            self.slot_mapping_np[:req_indices.shape[0]] = np.where(
+                mask, slot_mapping, -1)
+        else:
+            block_table_indices = (req_indices * self.max_num_blocks_per_req +
+                                   positions // self.block_size)
+            block_numbers = self.block_table_np.ravel()[block_table_indices]
+            block_offsets = positions % self.block_size
+            np.add(block_numbers * self.block_size,
+                   block_offsets,
+                   out=self.slot_mapping_np[:req_indices.shape[0]])
 
     def commit_block_table(self, num_reqs: int) -> None:
         self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
@@ -128,9 +159,19 @@ class MultiGroupBlockTable:
     def __init__(self, max_num_reqs: int, max_model_len: int,
                  max_num_batched_tokens: int, pin_memory: bool,
                  device: torch.device, block_sizes: list[int]) -> None:
+        # Note(hc): each dcp rank only stores
+        # (max_model_len // dcp_world_size) tokens in its kv cache,
+        # so the block_size used to calculate max_num_blocks_per_req
+        # must be multiplied by dcp_world_size.
+        try:
+            dcp_world_size = get_dcp_group().world_size
+        except AssertionError:
+            # DCP might not be initialized in testing
+            dcp_world_size = 1
+
         self.block_tables = [
-            BlockTable(block_size, max_num_reqs, cdiv(max_model_len,
-                                                      block_size),
+            BlockTable(block_size, max_num_reqs,
+                       cdiv(max_model_len, block_size * dcp_world_size),
                        max_num_batched_tokens, pin_memory, device)
             for block_size in block_sizes
         ]
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 5bee2dff9832..ba909f5e81b4 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -56,6 +56,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, cdiv, check_use_alibi,
                         get_dtype_size, is_pin_memory_available, round_up,
                         supports_dynamo)
+from vllm.v1.attention.backends.mla.flashmla import FlashMLABackend
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
     create_fast_prefill_custom_backend,
@@ -187,6 +188,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             model_config.is_multimodal_raw_input_only_model)
 
         self.max_model_len = model_config.max_model_len
+        self.dcp_world_size = self.parallel_config.decode_context_parallel_size
         self.max_num_tokens = scheduler_config.max_num_batched_tokens
         self.max_num_reqs = scheduler_config.max_num_seqs
 
@@ -428,6 +430,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             return
 
         if self.reorder_batch_threshold is not None:
+            if self.dcp_world_size > 1:
+                assert self.reorder_batch_threshold == 1, \
+                    "DCP does not support reorder_batch_threshold > 1 yet."
             reorder_batch_to_split_decodes_and_prefills(
                 self.input_batch,
                 scheduler_output,
@@ -3305,6 +3310,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 get_kv_transfer_group().set_host_xfer_buffer_ops(
                     copy_kv_blocks)
 
+        if self.dcp_world_size > 1:
+            assert self.attn_groups[0][0].backend is FlashMLABackend, (
+                "DCP only supports FlashMLA now. For an MLA backend to enable"
+ "For a mla backend want to enable DCP, it is mandatory that the" + "corresponding decode attn kernel return the softmax lse.") + def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: """ Add encoder-only layers to the KV cache config. diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 99c805a3e949..6a3bc5d46df2 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -616,7 +616,9 @@ def init_worker_distributed_environment( init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank, backend) - ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size) + ensure_model_parallel_initialized( + parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size, + parallel_config.decode_context_parallel_size) ensure_kv_transfer_initialized(vllm_config) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 08bb4e7c9e47..b4a67e2899d0 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -539,8 +539,10 @@ def init_worker_distributed_environment( init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank, current_platform.dist_backend) - ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size) + ensure_model_parallel_initialized( + parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size, + parallel_config.decode_context_parallel_size) ensure_kv_transfer_initialized(vllm_config)