mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-15 08:53:33 +08:00
[DCP] Support dcp kv_cache interleave size > 1 (#26696)
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com> Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com> Signed-off-by: Qiu <qiuchunshuo@huawei.com> Co-authored-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
parent
47604137a2
commit
2108a571d7
@ -30,6 +30,7 @@ class ParallelSetup(NamedTuple):
|
|||||||
tp_size: int
|
tp_size: int
|
||||||
pp_size: int
|
pp_size: int
|
||||||
dcp_size: int
|
dcp_size: int
|
||||||
|
dcp_kv_cache_interleave_size: int
|
||||||
eager_mode: bool
|
eager_mode: bool
|
||||||
chunked_prefill: bool
|
chunked_prefill: bool
|
||||||
|
|
||||||
@ -52,6 +53,7 @@ class CPTestSettings:
|
|||||||
tp_base: int = 4,
|
tp_base: int = 4,
|
||||||
pp_base: int = 1,
|
pp_base: int = 1,
|
||||||
dcp_base: int = 1,
|
dcp_base: int = 1,
|
||||||
|
dcp_kv_cache_interleave_size: int = 1,
|
||||||
multi_node_only: bool = False,
|
multi_node_only: bool = False,
|
||||||
runner: RunnerOption = "auto",
|
runner: RunnerOption = "auto",
|
||||||
load_format: str | None = None,
|
load_format: str | None = None,
|
||||||
@ -66,6 +68,7 @@ class CPTestSettings:
|
|||||||
tp_size=tp_base,
|
tp_size=tp_base,
|
||||||
pp_size=pp_multiplier * pp_base,
|
pp_size=pp_multiplier * pp_base,
|
||||||
dcp_size=int(dcp_multiplier * tp_base),
|
dcp_size=int(dcp_multiplier * tp_base),
|
||||||
|
dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
|
||||||
eager_mode=eager_mode_val,
|
eager_mode=eager_mode_val,
|
||||||
chunked_prefill=chunked_prefill_val,
|
chunked_prefill=chunked_prefill_val,
|
||||||
)
|
)
|
||||||
@ -108,6 +111,7 @@ def _compare_cp_with_tp(
|
|||||||
tp_size,
|
tp_size,
|
||||||
pp_size,
|
pp_size,
|
||||||
dcp_size,
|
dcp_size,
|
||||||
|
dcp_kv_cache_interleave_size,
|
||||||
eager_mode,
|
eager_mode,
|
||||||
chunked_prefill,
|
chunked_prefill,
|
||||||
) = parallel_setup
|
) = parallel_setup
|
||||||
@ -180,6 +184,8 @@ def _compare_cp_with_tp(
|
|||||||
str(pp_size),
|
str(pp_size),
|
||||||
"--decode-context-parallel-size",
|
"--decode-context-parallel-size",
|
||||||
str(dcp_size),
|
str(dcp_size),
|
||||||
|
"--dcp-kv-cache-interleave-size",
|
||||||
|
str(dcp_kv_cache_interleave_size),
|
||||||
"--distributed-executor-backend",
|
"--distributed-executor-backend",
|
||||||
distributed_backend,
|
distributed_backend,
|
||||||
]
|
]
|
||||||
@ -207,6 +213,7 @@ CP_TEXT_GENERATION_MODELS = {
|
|||||||
"deepseek-ai/DeepSeek-V2-Lite-Chat": [
|
"deepseek-ai/DeepSeek-V2-Lite-Chat": [
|
||||||
CPTestSettings.detailed(),
|
CPTestSettings.detailed(),
|
||||||
CPTestSettings.detailed(tp_base=2),
|
CPTestSettings.detailed(tp_base=2),
|
||||||
|
CPTestSettings.detailed(tp_base=2, dcp_kv_cache_interleave_size=64),
|
||||||
],
|
],
|
||||||
"bigcode/gpt_bigcode-santacoder": [
|
"bigcode/gpt_bigcode-santacoder": [
|
||||||
CPTestSettings.detailed(),
|
CPTestSettings.detailed(),
|
||||||
|
|||||||
@ -951,6 +951,7 @@ def test_hybrid_block_table_initialization():
|
|||||||
max_num_reqs = 10
|
max_num_reqs = 10
|
||||||
max_num_blocks_per_req = 20
|
max_num_blocks_per_req = 20
|
||||||
max_num_batched_tokens = 512
|
max_num_batched_tokens = 512
|
||||||
|
dcp_kv_cache_interleave_size = 8
|
||||||
|
|
||||||
block_table = BlockTable(
|
block_table = BlockTable(
|
||||||
block_size=block_size,
|
block_size=block_size,
|
||||||
@ -960,6 +961,7 @@ def test_hybrid_block_table_initialization():
|
|||||||
pin_memory=False,
|
pin_memory=False,
|
||||||
device=torch.device(DEVICE),
|
device=torch.device(DEVICE),
|
||||||
kernel_block_size=kernel_block_sizes[0],
|
kernel_block_size=kernel_block_sizes[0],
|
||||||
|
dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify hybrid block configuration
|
# Verify hybrid block configuration
|
||||||
|
|||||||
@ -53,6 +53,7 @@ def _correct_attn_cp_out_kernel(
|
|||||||
lse = tl.load(lses_ptr + lse_offsets)
|
lse = tl.load(lses_ptr + lse_offsets)
|
||||||
lse = tl.where((lse != lse) | (lse == float("inf")), -float("inf"), lse)
|
lse = tl.where((lse != lse) | (lse == float("inf")), -float("inf"), lse)
|
||||||
lse_max = tl.max(lse, axis=0)
|
lse_max = tl.max(lse, axis=0)
|
||||||
|
lse_max = tl.where(lse_max == -float("inf"), 0, lse_max)
|
||||||
lse -= lse_max
|
lse -= lse_max
|
||||||
lse_exp = tl.exp(lse)
|
lse_exp = tl.exp(lse)
|
||||||
lse_acc = tl.sum(lse_exp, axis=0)
|
lse_acc = tl.sum(lse_exp, axis=0)
|
||||||
|
|||||||
@ -227,6 +227,17 @@ class ParallelConfig:
|
|||||||
not change by dcp, it simply reuse the GPUs of TP group, and tp_size
|
not change by dcp, it simply reuse the GPUs of TP group, and tp_size
|
||||||
needs to be divisible by dcp_size."""
|
needs to be divisible by dcp_size."""
|
||||||
|
|
||||||
|
dcp_kv_cache_interleave_size: int = 1
|
||||||
|
"""Interleave size of kv_cache storage while using dcp or cp > 1,
|
||||||
|
store interleave_size tokens on (d)cp i,
|
||||||
|
then store next interleave_size tokens on (d)cp i+1.
|
||||||
|
Interleave_size=1: token-level align, token i is stored on rank i % (d)cp_size.
|
||||||
|
Interleave_size=block_size: block-level align, first fill the block on first rank,
|
||||||
|
token is stored on rank i+1 block j after rank i block j is full.
|
||||||
|
Block_size should be greater than or equal to dcp_kv_cache_interleave_size.
|
||||||
|
Block_size should be divisible by dcp_kv_cache_interleave_size.
|
||||||
|
"""
|
||||||
|
|
||||||
_api_process_count: int = Field(default=1, gt=0)
|
_api_process_count: int = Field(default=1, gt=0)
|
||||||
"""
|
"""
|
||||||
The number of API processes initialized.
|
The number of API processes initialized.
|
||||||
|
|||||||
@ -608,6 +608,23 @@ class VllmConfig:
|
|||||||
)
|
)
|
||||||
current_platform.check_and_update_config(self)
|
current_platform.check_and_update_config(self)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
self.parallel_config.dcp_kv_cache_interleave_size
|
||||||
|
<= self.cache_config.block_size
|
||||||
|
and self.cache_config.block_size
|
||||||
|
% self.parallel_config.dcp_kv_cache_interleave_size
|
||||||
|
== 0
|
||||||
|
), (
|
||||||
|
f"Block_size({self.cache_config.block_size}) should be "
|
||||||
|
"greater than or equal to and divisible by dcp_kv_cache_interleave_size "
|
||||||
|
f"({self.parallel_config.dcp_kv_cache_interleave_size})."
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
self.parallel_config.dcp_kv_cache_interleave_size == 1
|
||||||
|
or self.speculative_config is None
|
||||||
|
), "MTP with dcp_kv_cache_interleave_size > 1 is not supported now."
|
||||||
|
|
||||||
# Do this after all the updates to compilation_config.mode
|
# Do this after all the updates to compilation_config.mode
|
||||||
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
|
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
|
||||||
self.compilation_config.set_splitting_ops_for_v1()
|
self.compilation_config.set_splitting_ops_for_v1()
|
||||||
|
|||||||
@ -385,6 +385,7 @@ class EngineArgs:
|
|||||||
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
|
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
|
||||||
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
|
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
|
||||||
decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
|
decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
|
||||||
|
dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
|
||||||
data_parallel_size: int = ParallelConfig.data_parallel_size
|
data_parallel_size: int = ParallelConfig.data_parallel_size
|
||||||
data_parallel_rank: int | None = None
|
data_parallel_rank: int | None = None
|
||||||
data_parallel_start_rank: int | None = None
|
data_parallel_start_rank: int | None = None
|
||||||
@ -750,6 +751,10 @@ class EngineArgs:
|
|||||||
"-dcp",
|
"-dcp",
|
||||||
**parallel_kwargs["decode_context_parallel_size"],
|
**parallel_kwargs["decode_context_parallel_size"],
|
||||||
)
|
)
|
||||||
|
parallel_group.add_argument(
|
||||||
|
"--dcp-kv-cache-interleave-size",
|
||||||
|
**parallel_kwargs["dcp_kv_cache_interleave_size"],
|
||||||
|
)
|
||||||
parallel_group.add_argument(
|
parallel_group.add_argument(
|
||||||
"--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
|
"--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
|
||||||
)
|
)
|
||||||
@ -1518,6 +1523,7 @@ class EngineArgs:
|
|||||||
worker_cls=self.worker_cls,
|
worker_cls=self.worker_cls,
|
||||||
worker_extension_cls=self.worker_extension_cls,
|
worker_extension_cls=self.worker_extension_cls,
|
||||||
decode_context_parallel_size=self.decode_context_parallel_size,
|
decode_context_parallel_size=self.decode_context_parallel_size,
|
||||||
|
dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size,
|
||||||
_api_process_count=self._api_process_count,
|
_api_process_count=self._api_process_count,
|
||||||
_api_process_rank=self._api_process_rank,
|
_api_process_rank=self._api_process_rank,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -43,6 +43,7 @@ from vllm.v1.attention.backends.utils import (
|
|||||||
AttentionCGSupport,
|
AttentionCGSupport,
|
||||||
AttentionMetadataBuilder,
|
AttentionMetadataBuilder,
|
||||||
CommonAttentionMetadata,
|
CommonAttentionMetadata,
|
||||||
|
get_dcp_local_seq_lens,
|
||||||
get_kv_cache_layout,
|
get_kv_cache_layout,
|
||||||
)
|
)
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
@ -238,6 +239,10 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
|||||||
self.dcp_world_size = 1
|
self.dcp_world_size = 1
|
||||||
self.dcp_rank = 0
|
self.dcp_rank = 0
|
||||||
|
|
||||||
|
self.dcp_kv_cache_interleave_size = (
|
||||||
|
self.parallel_config.dcp_kv_cache_interleave_size
|
||||||
|
)
|
||||||
|
|
||||||
self.use_full_cuda_graph = (
|
self.use_full_cuda_graph = (
|
||||||
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||||
)
|
)
|
||||||
@ -352,8 +357,12 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
|||||||
- common_attn_metadata.query_start_loc_cpu[:-1]
|
- common_attn_metadata.query_start_loc_cpu[:-1]
|
||||||
)
|
)
|
||||||
dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu
|
dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu
|
||||||
dcp_context_kv_lens_cpu = dcp_context_kv_lens_cpu // self.dcp_world_size + (
|
|
||||||
self.dcp_rank <= (dcp_context_kv_lens_cpu - 1) % self.dcp_world_size
|
dcp_context_kv_lens_cpu = get_dcp_local_seq_lens(
|
||||||
|
dcp_context_kv_lens_cpu,
|
||||||
|
self.dcp_world_size,
|
||||||
|
self.dcp_rank,
|
||||||
|
self.dcp_kv_cache_interleave_size,
|
||||||
)
|
)
|
||||||
dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
|
dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
|
||||||
max_dcp_context_kv_len = dcp_context_kv_lens.max().item()
|
max_dcp_context_kv_len = dcp_context_kv_lens.max().item()
|
||||||
|
|||||||
@ -225,6 +225,7 @@ from vllm.utils.math_utils import cdiv, round_down
|
|||||||
from vllm.v1.attention.backends.utils import (
|
from vllm.v1.attention.backends.utils import (
|
||||||
AttentionMetadataBuilder,
|
AttentionMetadataBuilder,
|
||||||
CommonAttentionMetadata,
|
CommonAttentionMetadata,
|
||||||
|
get_dcp_local_seq_lens,
|
||||||
get_per_layer_parameters,
|
get_per_layer_parameters,
|
||||||
infer_global_hyperparameters,
|
infer_global_hyperparameters,
|
||||||
split_decodes_and_prefills,
|
split_decodes_and_prefills,
|
||||||
@ -361,10 +362,9 @@ class MLACommonPrefillMetadata:
|
|||||||
workspace: torch.Tensor
|
workspace: torch.Tensor
|
||||||
|
|
||||||
# for mla DCP
|
# for mla DCP
|
||||||
cp_chunk_seq_lens: list[list[int]] | None = None
|
padded_local_chunk_seq_lens: list[list[int]] | None = None
|
||||||
origin_context_lens: list[int] | None = None
|
local_context_lens_allranks: list[list[int]] | None = None
|
||||||
cp_cu_seq_lens: torch.Tensor | None = None
|
padded_local_cu_seq_lens: torch.Tensor | None = None
|
||||||
chunk_size: int | None = None
|
|
||||||
cu_seq_lens_lst: list[list[int]] | None = None
|
cu_seq_lens_lst: list[list[int]] | None = None
|
||||||
|
|
||||||
block_table: torch.Tensor
|
block_table: torch.Tensor
|
||||||
@ -568,6 +568,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
# DCP might not be initialized in testing
|
# DCP might not be initialized in testing
|
||||||
self.dcp_world_size = 1
|
self.dcp_world_size = 1
|
||||||
self.dcp_rank = 0
|
self.dcp_rank = 0
|
||||||
|
self.dcp_local_block_size = parallel_config.dcp_kv_cache_interleave_size
|
||||||
|
self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size
|
||||||
|
|
||||||
# Don't try to access the runner on AMD
|
# Don't try to access the runner on AMD
|
||||||
if self.aot_schedule:
|
if self.aot_schedule:
|
||||||
@ -794,15 +796,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Note(hc): update seq_lens of decode reqs under DCP.
|
|
||||||
if self.dcp_world_size > 1:
|
|
||||||
assert dcp_local_seq_lens is not None
|
|
||||||
dcp_local_seq_lens[:num_decodes] = seq_lens[
|
|
||||||
:num_decodes
|
|
||||||
] // self.dcp_world_size + (
|
|
||||||
self.dcp_rank < seq_lens[:num_decodes] % self.dcp_world_size
|
|
||||||
)
|
|
||||||
|
|
||||||
assert num_decodes + num_prefills == num_reqs
|
assert num_decodes + num_prefills == num_reqs
|
||||||
assert num_decode_tokens + num_prefill_tokens == num_tokens
|
assert num_decode_tokens + num_prefill_tokens == num_tokens
|
||||||
|
|
||||||
@ -811,11 +804,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
reqs_start = num_decodes # prefill_start
|
reqs_start = num_decodes # prefill_start
|
||||||
|
|
||||||
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
|
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
|
||||||
# Note(hc): The context lengths in the perspective of dcp rank0.
|
|
||||||
cp_context_lens_cpu = torch.ceil(
|
|
||||||
context_lens_cpu.float() / self.dcp_world_size
|
|
||||||
).int()
|
|
||||||
origin_context_lens = context_lens_cpu.tolist()
|
|
||||||
max_context_len_cpu = context_lens_cpu.max().item()
|
max_context_len_cpu = context_lens_cpu.max().item()
|
||||||
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
|
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
|
||||||
prefill_query_start_loc = (
|
prefill_query_start_loc = (
|
||||||
@ -871,32 +859,56 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.dcp_world_size > 1:
|
if self.dcp_world_size > 1:
|
||||||
|
local_context_lens_allranks = get_dcp_local_seq_lens(
|
||||||
|
context_lens_cpu,
|
||||||
|
self.dcp_world_size,
|
||||||
|
None,
|
||||||
|
self.dcp_local_block_size,
|
||||||
|
)
|
||||||
|
# Note(qcs): The max local context lengths
|
||||||
|
# padded to `dcp_local_block_size`.
|
||||||
|
padded_local_context_lens_cpu = (
|
||||||
|
cdiv(
|
||||||
|
context_lens_cpu,
|
||||||
|
self.dcp_virtual_block_size,
|
||||||
|
)
|
||||||
|
* self.dcp_local_block_size
|
||||||
|
)
|
||||||
# Note(hc): The above max_context_chunk already enforces
|
# Note(hc): The above max_context_chunk already enforces
|
||||||
# block_size alignment, DCP just need the block_size can
|
# block_size alignment, DCP just need the block_size can
|
||||||
# be divisible by dcp_world_size, because DCP use
|
# be divisible by dcp_world_size, because DCP use
|
||||||
# cp_gather_cache which not require `cp_chunk_starts`
|
# cp_gather_cache which not require `cp_chunk_starts`
|
||||||
# aligned to page_size.
|
# aligned to page_size.
|
||||||
assert max_context_chunk % self.dcp_world_size == 0
|
assert max_context_chunk % self.dcp_world_size == 0
|
||||||
cp_max_context_chunk = max_context_chunk // self.dcp_world_size
|
padded_local_max_context_chunk_across_ranks = (
|
||||||
cp_chunk_starts = (
|
cdiv(
|
||||||
|
max_context_chunk,
|
||||||
|
self.dcp_virtual_block_size,
|
||||||
|
)
|
||||||
|
* self.dcp_local_block_size
|
||||||
|
)
|
||||||
|
local_chunk_starts = (
|
||||||
torch.arange(num_chunks, dtype=torch.int32)
|
torch.arange(num_chunks, dtype=torch.int32)
|
||||||
.unsqueeze(1)
|
.unsqueeze(1)
|
||||||
.expand(-1, num_prefills)
|
.expand(-1, num_prefills)
|
||||||
* cp_max_context_chunk
|
* padded_local_max_context_chunk_across_ranks
|
||||||
)
|
)
|
||||||
cp_chunk_ends = torch.min(
|
local_chunk_ends = torch.min(
|
||||||
cp_context_lens_cpu.unsqueeze(0),
|
padded_local_context_lens_cpu.unsqueeze(0),
|
||||||
cp_chunk_starts + cp_max_context_chunk,
|
local_chunk_starts
|
||||||
|
+ padded_local_max_context_chunk_across_ranks,
|
||||||
)
|
)
|
||||||
cp_chunk_seq_lens = (cp_chunk_ends - cp_chunk_starts).clamp(min=0)
|
padded_local_chunk_seq_lens = (
|
||||||
|
local_chunk_ends - local_chunk_starts
|
||||||
|
).clamp(min=0)
|
||||||
|
|
||||||
cp_cu_seq_lens_cpu = torch.zeros(
|
padded_local_cu_chunk_seq_lens_cpu = torch.zeros(
|
||||||
num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True
|
num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True
|
||||||
)
|
)
|
||||||
torch.cumsum(
|
torch.cumsum(
|
||||||
cp_chunk_seq_lens,
|
padded_local_chunk_seq_lens,
|
||||||
dim=1,
|
dim=1,
|
||||||
out=cp_cu_seq_lens_cpu[:, 1:],
|
out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -908,15 +920,16 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
if self.dcp_world_size > 1:
|
if self.dcp_world_size > 1:
|
||||||
chunked_context_metadata = chunked_context_metadata_cls(
|
chunked_context_metadata = chunked_context_metadata_cls(
|
||||||
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
|
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
|
||||||
starts=cp_chunk_starts.to(device, non_blocking=True),
|
starts=local_chunk_starts.to(device, non_blocking=True),
|
||||||
seq_tot=cp_chunk_seq_lens.sum(dim=1).tolist(),
|
seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
|
||||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||||
seq_lens=chunk_seq_lens,
|
seq_lens=chunk_seq_lens,
|
||||||
workspace=self.chunked_prefill_workspace,
|
workspace=self.chunked_prefill_workspace,
|
||||||
cp_chunk_seq_lens=cp_chunk_seq_lens.tolist(),
|
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
|
||||||
origin_context_lens=origin_context_lens,
|
local_context_lens_allranks=local_context_lens_allranks.tolist(),
|
||||||
cp_cu_seq_lens=cp_cu_seq_lens_cpu.to(device, non_blocking=True),
|
padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.to(
|
||||||
chunk_size=max_context_chunk,
|
device, non_blocking=True
|
||||||
|
),
|
||||||
cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
|
cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -998,64 +1011,52 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
def reorg_kvcache(
|
def reorg_kvcache(
|
||||||
allgatered_kv_c_normed: torch.Tensor,
|
allgatered_kv_c_normed: torch.Tensor,
|
||||||
allgatered_k_pe: torch.Tensor,
|
allgatered_k_pe: torch.Tensor,
|
||||||
cp_chunk_seq_lens_lst: list[int],
|
padded_local_chunk_seq_lens_lst: list[int],
|
||||||
origin_context_lens: list[int],
|
local_context_lens_allranks: list[list[int]],
|
||||||
cp_world_size: int,
|
|
||||||
sum_seq_len: int,
|
sum_seq_len: int,
|
||||||
max_seq_len: int,
|
max_seq_len: int,
|
||||||
chunk_size: int,
|
|
||||||
chunk_idx: int,
|
|
||||||
toks: int,
|
toks: int,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
reorg kvcache after cp local gather to tp layout for attn kernel.
|
reorg and unpad kvcache after cp local gather to tp layout for attn kernel.
|
||||||
|
e.g.
|
||||||
|
allgatered_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T1_0, T1_1, ...,
|
||||||
|
T0_4, T0_5, pad, pad, T1_2, pad, ...]
|
||||||
|
-> reorganized_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T0_4, T0_5,
|
||||||
|
T1_0, T1_1, T1_2, ...]
|
||||||
Args:
|
Args:
|
||||||
cp_chunk_seq_lens_lst: chunk context lengths under CP.
|
padded_local_chunk_seq_lens_lst: local chunk context lengths
|
||||||
origin_context_lens: origin full context lengths under CP.
|
under current CP rank.
|
||||||
cp_world_size: CP size.
|
local_context_lens_allranks: local context lengths on each CP rank.
|
||||||
sum_seq_len: the sum of cp_chunk_seq_lens_lst.
|
sum_seq_len: the sum of cp_chunk_seq_lens_lst.
|
||||||
max_seq_len: the max value of cp_chunk_seq_lens_lst.
|
max_seq_len: the max value of cp_chunk_seq_lens_lst.
|
||||||
chunk_size: equals to max_context_chunk from
|
|
||||||
chunked_context_metadata building.
|
|
||||||
chunk_idx: chunk idx of chunked_prefill.
|
|
||||||
toks: the number of tokens for local gather cache.
|
toks: the number of tokens for local gather cache.
|
||||||
"""
|
"""
|
||||||
kv_c_segments = []
|
kv_c_segments = []
|
||||||
k_pe_segments = []
|
k_pe_segments = []
|
||||||
src_token_idx = 0
|
src_token_idx = 0
|
||||||
max_seq_len_check = 0
|
max_seq_len_check = 0
|
||||||
for cp_chunk_seq_len, origin_context_len in zip(
|
for padded_local_chunk_seq_len, local_context_lens in zip(
|
||||||
cp_chunk_seq_lens_lst, origin_context_lens
|
padded_local_chunk_seq_lens_lst, local_context_lens_allranks
|
||||||
):
|
):
|
||||||
chunk_context_len = chunk_size
|
|
||||||
if cp_chunk_seq_len != 0:
|
|
||||||
chunk_context_len = min(
|
|
||||||
chunk_context_len, origin_context_len - chunk_size * chunk_idx
|
|
||||||
)
|
|
||||||
cp_target_rank = (chunk_context_len - 1) % cp_world_size
|
|
||||||
cur_seq_len = 0
|
cur_seq_len = 0
|
||||||
for rank in range(cp_world_size):
|
for rank, local_context_len in enumerate(local_context_lens):
|
||||||
if rank > cp_target_rank and cp_chunk_seq_len:
|
if local_context_len != 0:
|
||||||
real_cp_chunk_seq_len = cp_chunk_seq_len - 1
|
|
||||||
else:
|
|
||||||
real_cp_chunk_seq_len = cp_chunk_seq_len
|
|
||||||
if real_cp_chunk_seq_len:
|
|
||||||
kv_c_segment = allgatered_kv_c_normed[
|
kv_c_segment = allgatered_kv_c_normed[
|
||||||
rank * toks + src_token_idx : rank * toks
|
rank * toks + src_token_idx : rank * toks
|
||||||
+ src_token_idx
|
+ src_token_idx
|
||||||
+ real_cp_chunk_seq_len
|
+ local_context_len
|
||||||
]
|
]
|
||||||
k_pe_segment = allgatered_k_pe[
|
k_pe_segment = allgatered_k_pe[
|
||||||
rank * toks + src_token_idx : rank * toks
|
rank * toks + src_token_idx : rank * toks
|
||||||
+ src_token_idx
|
+ src_token_idx
|
||||||
+ real_cp_chunk_seq_len
|
+ local_context_len
|
||||||
]
|
]
|
||||||
kv_c_segments.append(kv_c_segment)
|
kv_c_segments.append(kv_c_segment)
|
||||||
k_pe_segments.append(k_pe_segment)
|
k_pe_segments.append(k_pe_segment)
|
||||||
cur_seq_len += real_cp_chunk_seq_len
|
cur_seq_len += local_context_len
|
||||||
max_seq_len_check = max(max_seq_len_check, cur_seq_len)
|
max_seq_len_check = max(max_seq_len_check, cur_seq_len)
|
||||||
src_token_idx += cp_chunk_seq_len
|
src_token_idx += padded_local_chunk_seq_len
|
||||||
reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
|
reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
|
||||||
reorganized_k_pe = torch.cat(k_pe_segments, dim=0)
|
reorganized_k_pe = torch.cat(k_pe_segments, dim=0)
|
||||||
assert reorganized_kv_c_normed.shape[0] == sum_seq_len
|
assert reorganized_kv_c_normed.shape[0] == sum_seq_len
|
||||||
@ -1296,6 +1297,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
get_current_vllm_config()
|
get_current_vllm_config()
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
self.dcp_kv_cache_interleave_size: int = (
|
||||||
|
get_current_vllm_config().parallel_config.dcp_kv_cache_interleave_size
|
||||||
|
)
|
||||||
|
|
||||||
def _flash_attn_varlen_diff_headdims(
|
def _flash_attn_varlen_diff_headdims(
|
||||||
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
|
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
|
||||||
@ -1697,10 +1701,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
assert attn_metadata.prefill is not None
|
assert attn_metadata.prefill is not None
|
||||||
prefill_metadata = attn_metadata.prefill
|
prefill_metadata = attn_metadata.prefill
|
||||||
assert prefill_metadata.chunked_context is not None
|
assert prefill_metadata.chunked_context is not None
|
||||||
assert prefill_metadata.chunked_context.cp_chunk_seq_lens is not None
|
assert prefill_metadata.chunked_context.padded_local_chunk_seq_lens is not None
|
||||||
assert prefill_metadata.chunked_context.origin_context_lens is not None
|
assert prefill_metadata.chunked_context.local_context_lens_allranks is not None
|
||||||
assert prefill_metadata.chunked_context.cp_cu_seq_lens is not None
|
assert prefill_metadata.chunked_context.padded_local_cu_seq_lens is not None
|
||||||
assert prefill_metadata.chunked_context.chunk_size is not None
|
|
||||||
assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None
|
assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None
|
||||||
|
|
||||||
output = None
|
output = None
|
||||||
@ -1713,7 +1716,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
src_cache=kv_c_and_k_pe_cache,
|
src_cache=kv_c_and_k_pe_cache,
|
||||||
dst=workspace,
|
dst=workspace,
|
||||||
block_table=prefill_metadata.block_table,
|
block_table=prefill_metadata.block_table,
|
||||||
cu_seq_lens=prefill_metadata.chunked_context.cp_cu_seq_lens[i],
|
cu_seq_lens=prefill_metadata.chunked_context.padded_local_cu_seq_lens[
|
||||||
|
i
|
||||||
|
],
|
||||||
batch_size=attn_metadata.num_prefills,
|
batch_size=attn_metadata.num_prefills,
|
||||||
seq_starts=prefill_metadata.chunked_context.starts[i],
|
seq_starts=prefill_metadata.chunked_context.starts[i],
|
||||||
)
|
)
|
||||||
@ -1743,15 +1748,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
kv_c_normed, k_pe = reorg_kvcache(
|
kv_c_normed, k_pe = reorg_kvcache(
|
||||||
allgatered_kv_c_normed,
|
allgatered_kv_c_normed,
|
||||||
allgatered_k_pe,
|
allgatered_k_pe,
|
||||||
cp_chunk_seq_lens_lst=prefill_metadata.chunked_context.cp_chunk_seq_lens[
|
padded_local_chunk_seq_lens_lst=prefill_metadata.chunked_context.padded_local_chunk_seq_lens[
|
||||||
i
|
i
|
||||||
],
|
],
|
||||||
origin_context_lens=prefill_metadata.chunked_context.origin_context_lens,
|
local_context_lens_allranks=prefill_metadata.chunked_context.local_context_lens_allranks,
|
||||||
cp_world_size=dcp_world_size,
|
|
||||||
sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1],
|
sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1],
|
||||||
max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i],
|
max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i],
|
||||||
chunk_size=prefill_metadata.chunked_context.chunk_size,
|
|
||||||
chunk_idx=i,
|
|
||||||
toks=toks,
|
toks=toks,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -1076,3 +1076,41 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
|
|||||||
nums_dict[BLOCK_M]["token_chunk_offset_ptr"] = token_chunk_offset_ptr # type: ignore
|
nums_dict[BLOCK_M]["token_chunk_offset_ptr"] = token_chunk_offset_ptr # type: ignore
|
||||||
|
|
||||||
return nums_dict, batch_ptr, token_chunk_offset_ptr
|
return nums_dict, batch_ptr, token_chunk_offset_ptr
|
||||||
|
|
||||||
|
|
||||||
|
def get_dcp_local_seq_lens(
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
dcp_world_size: int = 1,
|
||||||
|
dcp_rank: int | None = None,
|
||||||
|
dcp_kv_cache_interleave_size: int = 1,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""While using dcp, kv_cache size stored on each rank may be different,
|
||||||
|
use this function to calculate split decode seq_lens of each dcp rank.
|
||||||
|
Only consider dcp now, we can extend the case of cp based on this.
|
||||||
|
"""
|
||||||
|
num_requests = seq_lens.size(0)
|
||||||
|
if dcp_rank is None:
|
||||||
|
rank_offsets = (
|
||||||
|
torch.arange(dcp_world_size, dtype=torch.int32)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.repeat(num_requests, 1)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
rank_offsets = torch.Tensor([[dcp_rank]]).to(dtype=torch.int32)
|
||||||
|
seq_lens_tiled = (
|
||||||
|
seq_lens.to(torch.int32).unsqueeze(-1).repeat(1, rank_offsets.shape[1])
|
||||||
|
)
|
||||||
|
base = (
|
||||||
|
seq_lens_tiled
|
||||||
|
// dcp_kv_cache_interleave_size
|
||||||
|
// dcp_world_size
|
||||||
|
* dcp_kv_cache_interleave_size
|
||||||
|
)
|
||||||
|
remainder = seq_lens_tiled - base * dcp_world_size
|
||||||
|
remainder = torch.clip(
|
||||||
|
remainder - rank_offsets * dcp_kv_cache_interleave_size,
|
||||||
|
0,
|
||||||
|
dcp_kv_cache_interleave_size,
|
||||||
|
)
|
||||||
|
dcp_local_seq_lens = base + remainder
|
||||||
|
return dcp_local_seq_lens.squeeze(1)
|
||||||
|
|||||||
@ -22,6 +22,7 @@ class BlockTable:
|
|||||||
pin_memory: bool,
|
pin_memory: bool,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
kernel_block_size: int,
|
kernel_block_size: int,
|
||||||
|
dcp_kv_cache_interleave_size: int,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -86,6 +87,7 @@ class BlockTable:
|
|||||||
# DCP might not be initialized in testing
|
# DCP might not be initialized in testing
|
||||||
self.dcp_world_size = 1
|
self.dcp_world_size = 1
|
||||||
self.dcp_rank = 0
|
self.dcp_rank = 0
|
||||||
|
self.dcp_kv_cache_interleave_size = dcp_kv_cache_interleave_size
|
||||||
|
|
||||||
def append_row(
|
def append_row(
|
||||||
self,
|
self,
|
||||||
@ -144,9 +146,19 @@ class BlockTable:
|
|||||||
# Use virtual_block_size for mask calculation, which marks local
|
# Use virtual_block_size for mask calculation, which marks local
|
||||||
# tokens.
|
# tokens.
|
||||||
virtual_block_offsets = positions % virtual_block_size
|
virtual_block_offsets = positions % virtual_block_size
|
||||||
mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank
|
mask = (
|
||||||
|
virtual_block_offsets
|
||||||
|
// self.dcp_kv_cache_interleave_size
|
||||||
|
% self.dcp_world_size
|
||||||
|
== self.dcp_rank
|
||||||
|
)
|
||||||
# Calculate local block_offsets
|
# Calculate local block_offsets
|
||||||
block_offsets = virtual_block_offsets // self.dcp_world_size
|
block_offsets = (
|
||||||
|
virtual_block_offsets
|
||||||
|
// (self.dcp_world_size * self.dcp_kv_cache_interleave_size)
|
||||||
|
* self.dcp_kv_cache_interleave_size
|
||||||
|
+ virtual_block_offsets % self.dcp_kv_cache_interleave_size
|
||||||
|
)
|
||||||
# Calculate slot_mapping
|
# Calculate slot_mapping
|
||||||
slot_mapping = block_numbers * self.block_size + block_offsets
|
slot_mapping = block_numbers * self.block_size + block_offsets
|
||||||
# Write final slots, use -1 for not-local
|
# Write final slots, use -1 for not-local
|
||||||
@ -234,6 +246,7 @@ class MultiGroupBlockTable:
|
|||||||
block_sizes: list[int],
|
block_sizes: list[int],
|
||||||
kernel_block_sizes: list[int],
|
kernel_block_sizes: list[int],
|
||||||
num_speculative_tokens: int = 0,
|
num_speculative_tokens: int = 0,
|
||||||
|
dcp_kv_cache_interleave_size: int = 1,
|
||||||
) -> None:
|
) -> None:
|
||||||
# Note(hc): each dcp rank only store
|
# Note(hc): each dcp rank only store
|
||||||
# (max_model_len//dcp_world_size) tokens in kvcache,
|
# (max_model_len//dcp_world_size) tokens in kvcache,
|
||||||
@ -263,6 +276,7 @@ class MultiGroupBlockTable:
|
|||||||
pin_memory,
|
pin_memory,
|
||||||
device,
|
device,
|
||||||
kernel_block_size,
|
kernel_block_size,
|
||||||
|
dcp_kv_cache_interleave_size,
|
||||||
)
|
)
|
||||||
for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
|
for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
|
||||||
]
|
]
|
||||||
|
|||||||
@ -84,6 +84,7 @@ class InputBatch:
|
|||||||
is_spec_decode: bool = False,
|
is_spec_decode: bool = False,
|
||||||
is_pooling_model: bool = False,
|
is_pooling_model: bool = False,
|
||||||
num_speculative_tokens: int = 0,
|
num_speculative_tokens: int = 0,
|
||||||
|
dcp_kv_cache_interleave_size: int = 1,
|
||||||
):
|
):
|
||||||
self.is_pooling_model = is_pooling_model
|
self.is_pooling_model = is_pooling_model
|
||||||
self.is_spec_decode = is_spec_decode
|
self.is_spec_decode = is_spec_decode
|
||||||
@ -137,6 +138,7 @@ class InputBatch:
|
|||||||
block_sizes=block_sizes,
|
block_sizes=block_sizes,
|
||||||
kernel_block_sizes=kernel_block_sizes,
|
kernel_block_sizes=kernel_block_sizes,
|
||||||
num_speculative_tokens=num_speculative_tokens,
|
num_speculative_tokens=num_speculative_tokens,
|
||||||
|
dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Sampling-related.
|
# Sampling-related.
|
||||||
|
|||||||
@ -35,6 +35,7 @@ from vllm.distributed.eplb.eplb_state import EplbState
|
|||||||
from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
|
from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
|
||||||
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
|
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
|
||||||
from vllm.distributed.parallel_state import (
|
from vllm.distributed.parallel_state import (
|
||||||
|
get_dcp_group,
|
||||||
get_pp_group,
|
get_pp_group,
|
||||||
get_tp_group,
|
get_tp_group,
|
||||||
graph_capture,
|
graph_capture,
|
||||||
@ -88,6 +89,7 @@ from vllm.v1.attention.backends.utils import (
|
|||||||
AttentionMetadataBuilder,
|
AttentionMetadataBuilder,
|
||||||
CommonAttentionMetadata,
|
CommonAttentionMetadata,
|
||||||
create_fast_prefill_custom_backend,
|
create_fast_prefill_custom_backend,
|
||||||
|
get_dcp_local_seq_lens,
|
||||||
reorder_batch_to_split_decodes_and_prefills,
|
reorder_batch_to_split_decodes_and_prefills,
|
||||||
split_attn_metadata,
|
split_attn_metadata,
|
||||||
)
|
)
|
||||||
@ -275,6 +277,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
self.is_multimodal_pruning_enabled = False
|
self.is_multimodal_pruning_enabled = False
|
||||||
self.max_model_len = model_config.max_model_len
|
self.max_model_len = model_config.max_model_len
|
||||||
self.dcp_world_size = self.parallel_config.decode_context_parallel_size
|
self.dcp_world_size = self.parallel_config.decode_context_parallel_size
|
||||||
|
self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group
|
||||||
self.max_num_tokens = scheduler_config.max_num_batched_tokens
|
self.max_num_tokens = scheduler_config.max_num_batched_tokens
|
||||||
self.max_num_reqs = scheduler_config.max_num_seqs
|
self.max_num_reqs = scheduler_config.max_num_seqs
|
||||||
|
|
||||||
@ -396,6 +399,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
# uses output token ids so we set this conservatively.
|
# uses output token ids so we set this conservatively.
|
||||||
logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
|
logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
|
||||||
is_pooling_model=self.is_pooling_model,
|
is_pooling_model=self.is_pooling_model,
|
||||||
|
dcp_kv_cache_interleave_size=self.parallel_config.dcp_kv_cache_interleave_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.use_async_scheduling = self.scheduler_config.async_scheduling
|
self.use_async_scheduling = self.scheduler_config.async_scheduling
|
||||||
@ -1307,6 +1311,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
logits_indices
|
logits_indices
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# update seq_lens of decode reqs under DCP.
|
||||||
|
if self.dcp_world_size > 1:
|
||||||
|
self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens(
|
||||||
|
self.seq_lens.cpu[:num_reqs],
|
||||||
|
self.dcp_world_size,
|
||||||
|
self.dcp_rank,
|
||||||
|
self.parallel_config.dcp_kv_cache_interleave_size,
|
||||||
|
)
|
||||||
|
self.dcp_local_seq_lens.copy_to_gpu(num_reqs)
|
||||||
|
|
||||||
attn_metadata: PerLayerAttnMetadata = {}
|
attn_metadata: PerLayerAttnMetadata = {}
|
||||||
if ubatch_slices is not None:
|
if ubatch_slices is not None:
|
||||||
attn_metadata = [dict() for _ in range(len(ubatch_slices))]
|
attn_metadata = [dict() for _ in range(len(ubatch_slices))]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user