[DCP] Support Decode Context Parallel (DCP) for GQA with Flashinfer (#25438)

Signed-off-by: gaojc <1055866782@qq.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: Jingchun Gao <63247409+gjc0824@users.noreply.github.com>
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Co-authored-by: gaojingchun (A) <g00955623@china.huawei.com>
Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com>
Co-authored-by: QiuChunshuo <qiuchunshuo@huawei.com>
Authored by Jingchun Gao on 2025-11-14 19:24:10 +08:00; committed by GitHub
parent 41b92f7d38
commit 4516d44b7f
5 changed files with 331 additions and 51 deletions

View File

@ -39,6 +39,7 @@ class ParallelSetup(NamedTuple):
class CPTestOptions(NamedTuple):
multi_node_only: bool
load_format: str | None = None
attn_backend: str | None = None
@dataclass
@ -58,6 +59,7 @@ class CPTestSettings:
multi_node_only: bool = False,
runner: RunnerOption = "auto",
load_format: str | None = None,
attn_backend: str | None = None,
):
parallel_setups = []
for eager_mode_val in [False]:
@ -79,7 +81,9 @@ class CPTestSettings:
distributed_backends=["mp"],
runner=runner,
test_options=CPTestOptions(
multi_node_only=multi_node_only, load_format=load_format
multi_node_only=multi_node_only,
load_format=load_format,
attn_backend=attn_backend,
),
)
@ -117,7 +121,7 @@ def _compare_cp_with_tp(
chunked_prefill,
) = parallel_setup
multi_node_only, load_format = test_options
multi_node_only, load_format, attn_backend = test_options
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
model_info.check_transformers_version(on_fail="skip")
@ -177,6 +181,13 @@ def _compare_cp_with_tp(
if hf_overrides:
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
if not attn_backend:
cp_env = tp_env = {}
else:
cp_env = tp_env = {
"VLLM_ATTENTION_BACKEND": attn_backend,
}
cp_args = [
*common_args,
"--tensor-parallel-size",
@ -205,6 +216,8 @@ def _compare_cp_with_tp(
model_id,
cp_args,
tp_args,
cp_env,
tp_env,
method=method,
max_wait_seconds=720,
)
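
As a hedged illustration of how the new attn_backend field flows into the comparison above (the actual pytest parametrization is not part of this diff; the CPTestOptions stand-in below simply mirrors the NamedTuple shown earlier, and "FLASHINFER" is the backend this PR targets):

from typing import NamedTuple

# Minimal stand-in for CPTestOptions, only to show how attn_backend is turned into
# the VLLM_ATTENTION_BACKEND environment for both the CP run and the TP reference run.
class CPTestOptions(NamedTuple):
    multi_node_only: bool
    load_format: str | None = None
    attn_backend: str | None = None

opts = CPTestOptions(multi_node_only=False, attn_backend="FLASHINFER")
cp_env = tp_env = (
    {} if not opts.attn_backend else {"VLLM_ATTENTION_BACKEND": opts.attn_backend}
)
print(cp_env)  # {'VLLM_ATTENTION_BACKEND': 'FLASHINFER'}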

View File

@ -1183,6 +1183,14 @@ class ModelConfig:
f"but got {decode_context_parallel_size}"
)
num_q_per_kv = total_num_attention_heads // total_num_kv_heads
assert num_q_per_kv % decode_context_parallel_size == 0, (
f"Total number of q per kv attn heads ({num_q_per_kv})"
" must be divisible by dcp world size when enable "
"decode context parallel for GQA "
f"({parallel_config.decode_context_parallel_size})."
)
def get_sliding_window(self) -> int | None:
"""Get the sliding window size from the HF text config if present."""
return getattr(self.hf_text_config, "sliding_window", None)
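
A quick numeric illustration of the assertion added above (toy head counts, chosen only for illustration):

total_num_attention_heads = 32
total_num_kv_heads = 8
num_q_per_kv = total_num_attention_heads // total_num_kv_heads  # 4 query heads per KV head

# The assert requires the GQA group size to be divisible by the DCP world size.
for decode_context_parallel_size in (1, 2, 3, 4):
    ok = num_q_per_kv % decode_context_parallel_size == 0
    print(decode_context_parallel_size, "ok" if ok else "rejected by the new assert")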

View File

@ -259,6 +259,7 @@ def use_trtllm_attention(
num_kv_heads: int,
num_tokens: int,
max_seq_len: int,
dcp_world_size: int,
kv_cache_dtype: str,
q_dtype: torch.dtype,
is_prefill: bool,
@ -272,6 +273,14 @@ def use_trtllm_attention(
if force_use_trtllm is not None and not force_use_trtllm:
return False
# Decode context parallel is not supported
if dcp_world_size > 1:
logger.warning_once(
"Trtllm does not support returning LSE and as a result "
"does not support DCP, reverting to FlashInfer"
)
return False
# The platform is not supported
if not supports_trtllm_attention():
if force_use_trtllm:

View File

@ -10,6 +10,7 @@ import torch
from flashinfer import (
BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
BatchPrefillWithRaggedKVCacheWrapper,
MultiLevelCascadeAttentionWrapper,
)
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
@ -24,8 +25,11 @@ from vllm.attention.backends.abstract import (
AttentionType,
MultipleOf,
)
from vllm.attention.ops.common import cp_lse_ag_out_rs
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.config.cache import CacheDType
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
@ -50,6 +54,7 @@ from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
KVCacheLayoutType,
get_dcp_local_seq_lens,
get_kv_cache_layout,
get_per_layer_parameters,
infer_global_hyperparameters,
@ -160,6 +165,113 @@ def trtllm_prefill_attn_kvfp8_dequant(
return mock_kv_cache, mock_block_table
class BatchDCPPrefillWrapper:
def __init__(
self,
workspace_buffer: torch.Tensor | None = None,
):
self._context = BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, get_kv_cache_layout()
)
self._new_tokens = BatchPrefillWithRaggedKVCacheWrapper(
workspace_buffer, get_kv_cache_layout()
)
def plan(
self,
qo_indptr_cpu: torch.Tensor,
paged_kv_indptr_cpu: torch.Tensor,
paged_kv_indices: torch.Tensor,
paged_kv_last_page_len_cpu: torch.Tensor,
prefill_start: int,
page_size: int,
num_qo_heads: int,
dcp_world_size: int,
num_kv_heads: int,
head_dim: int,
sm_scale: float,
window_left: int,
logits_soft_cap: float | None,
q_data_type: torch.dtype,
kv_cache_dtype: torch.dtype,
prefill_fixed_split_size: int,
disable_split_kv: bool,
):
"""Plan the prefill operation with given parameters."""
self._context.plan(
qo_indptr_cpu,
paged_kv_indptr_cpu,
paged_kv_indices,
paged_kv_last_page_len_cpu[prefill_start:],
num_qo_heads * dcp_world_size,
num_kv_heads,
head_dim,
page_size,
causal=False,  # context pass: non-causal over the cached (historical) KV
sm_scale=sm_scale,
window_left=window_left,
logits_soft_cap=logits_soft_cap,
q_data_type=q_data_type,
kv_data_type=kv_cache_dtype,
fixed_split_size=prefill_fixed_split_size,
disable_split_kv=disable_split_kv,
)
self._new_tokens.plan(
qo_indptr=qo_indptr_cpu,
kv_indptr=qo_indptr_cpu,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim_qk=head_dim,
head_dim_vo=head_dim,
causal=True,  # new-token pass: causal attention among the newly appended tokens
sm_scale=sm_scale,
window_left=window_left,
logits_soft_cap=logits_soft_cap,
q_data_type=q_data_type,
)
def run(
self,
layer: torch.nn.Module,
prefill_query: torch.Tensor,
kv_cache_permute: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
):
prefill_query_across_dcp = get_dcp_group().all_gather(
prefill_query.contiguous(), dim=1
)
output_context_tmp, lse_context_tmp = self._context.run(
prefill_query_across_dcp,
kv_cache_permute,
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
return_lse=True,
)
output_context, lse_context = cp_lse_ag_out_rs(
output_context_tmp, lse_context_tmp, get_dcp_group(), return_lse=True
)
lse_context = lse_context.transpose(0, 1).contiguous()
output_query, lse_query = self._new_tokens.run(
prefill_query,
key,
value,
return_lse=True,
)
lse_query = lse_query.transpose(0, 1).contiguous()
merge_attn_states(
out,
output_context,
lse_context,
output_query,
lse_query,
)
return out
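
The wrapper above runs two attention passes, a non-causal context pass over the paged KV and a causal pass over the new tokens, and then fuses them with merge_attn_states; the decode path later fuses per-rank partials the same way via cp_lse_ag_out_rs. Below is a self-contained, simplified sketch (plain PyTorch, toy shapes, no causal masking, none of the vLLM/FlashInfer APIs) of the log-sum-exp identity that makes this recombination exact:

import torch

torch.manual_seed(0)
num_tokens, num_heads, head_dim = 4, 2, 8
q = torch.randn(num_tokens, num_heads, head_dim)
k1, v1 = torch.randn(6, num_heads, head_dim), torch.randn(6, num_heads, head_dim)
k2, v2 = torch.randn(5, num_heads, head_dim), torch.randn(5, num_heads, head_dim)
scale = head_dim ** -0.5

def attn_with_lse(q, k, v):
    # plain softmax attention that also returns the log-sum-exp of the logits,
    # mimicking return_lse=True in the wrappers above
    logits = torch.einsum("qhd,khd->hqk", q, k) * scale
    lse = torch.logsumexp(logits, dim=-1)                       # [heads, q_tokens]
    out = torch.einsum("hqk,khd->qhd", torch.softmax(logits, dim=-1), v)
    return out, lse

o1, lse1 = attn_with_lse(q, k1, v1)  # e.g. the "context" part (or one rank's KV shard)
o2, lse2 = attn_with_lse(q, k2, v2)  # e.g. the "new tokens" part (or another shard)

# merge the partial results: weight each output by exp(lse_i - lse_total)
lse_total = torch.logsumexp(torch.stack([lse1, lse2]), dim=0)
w1 = torch.exp(lse1 - lse_total).transpose(0, 1).unsqueeze(-1)  # [q_tokens, heads, 1]
w2 = torch.exp(lse2 - lse_total).transpose(0, 1).unsqueeze(-1)
merged = w1 * o1 + w2 * o2

full, _ = attn_with_lse(q, torch.cat([k1, k2]), torch.cat([v1, v2]))
assert torch.allclose(merged, full, atol=1e-5)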
class FlashInferBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
@ -281,7 +393,9 @@ class FlashInferMetadata:
# For cascade attention (CPU for planning).
use_cascade: bool
prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper | None = None
prefill_wrapper: (
BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None
) = None
decode_wrapper: BatchDecodeWithPagedKVCacheWrapper | None = None
cascade_wrapper: MultiLevelCascadeAttentionWrapper | None = None
@ -303,7 +417,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.cache_config = vllm_config.cache_config
self.model_config = vllm_config.model_config
self._workspace_buffer = None
self._prefill_wrapper = None # Wrapper for prefill/append
self._prefill_wrapper: (
BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None
) = None # Wrapper for prefill/append
self._decode_wrapper = None # Wrapper for decode (general shape)
if vllm_is_batch_invariant():
@ -341,9 +457,23 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.compilation_config.max_cudagraph_capture_size,
)
self.num_qo_heads = self.model_config.get_num_attention_heads(
self.vllm_config.parallel_config
try:
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
self.dcp_kv_cache_interleave_size = (
vllm_config.parallel_config.dcp_kv_cache_interleave_size
)
except AssertionError:
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
self.dcp_kv_cache_interleave_size = 1
self.num_qo_heads = (
self.model_config.get_num_attention_heads(self.vllm_config.parallel_config)
* self.dcp_world_size
)
self.num_kv_heads = self.kv_cache_spec.num_kv_heads
self.head_dim = self.kv_cache_spec.head_size
self.page_size = self.kv_cache_spec.block_size
@ -455,11 +585,19 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
return self._workspace_buffer
def _get_prefill_wrapper(self):
def _get_prefill_wrapper(
self,
) -> BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper:
if self._prefill_wrapper is None:
self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
self._get_workspace_buffer(), get_kv_cache_layout()
)
if self.dcp_world_size > 1:
self._prefill_wrapper = BatchDCPPrefillWrapper(
workspace_buffer=self._get_workspace_buffer(),
)
else:
self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
self._get_workspace_buffer(), get_kv_cache_layout()
)
assert self._prefill_wrapper is not None
return self._prefill_wrapper
def _get_decode_wrapper(self, batch_size: int, use_cudagraph: bool = False):
@ -526,9 +664,29 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
max_seq_len = common_attn_metadata.max_seq_len
seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
seq_lens_np = seq_lens_cpu.numpy()
block_table_tensor = common_attn_metadata.block_table_tensor
qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
if self.dcp_world_size > 1:
if num_prefills > 0:
qo_indptr_prefill_cpu = (
qo_indptr_cpu[num_decodes:] - qo_indptr_cpu[num_decodes]
)
query_lens_prefill_cpu = (
qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1]
)
seq_lens_cpu[num_decodes:] = (
seq_lens_cpu[num_decodes:] - query_lens_prefill_cpu
)
seq_lens_cpu = get_dcp_local_seq_lens(
seq_lens_cpu,
self.dcp_world_size,
self.dcp_rank,
self.dcp_kv_cache_interleave_size,
)
seq_lens_np = seq_lens_cpu.numpy()
num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size
use_cascade = common_prefix_len > 0
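
For intuition, a hedged worked example of the per-rank lengths that get_dcp_local_seq_lens produces, assuming the token-interleaved KV layout with dcp_kv_cache_interleave_size=1 (the helper below is illustrative plain Python, not the vLLM implementation):

def local_seq_len(seq_len: int, dcp_world_size: int, dcp_rank: int) -> int:
    # with interleave size 1, cached token i is assumed to live on rank i % dcp_world_size
    return len(range(dcp_rank, seq_len, dcp_world_size))

seq_len, dcp_world_size = 9, 2
print([local_seq_len(seq_len, dcp_world_size, r) for r in range(dcp_world_size)])  # [5, 4]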
@ -589,7 +747,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# write self.paged_kv_last_page_len_cpu inplace
paged_kv_last_page_len_np = seq_lens_np % page_size
self.paged_kv_last_page_len_np[:num_reqs] = np.where(
paged_kv_last_page_len_np == 0,
(paged_kv_last_page_len_np == 0) & (seq_lens_np != 0),
page_size,
paged_kv_last_page_len_np,
)
@ -600,13 +758,16 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.num_kv_heads,
num_prefill_tokens,
max_seq_len,
self.dcp_world_size,
self.cache_dtype,
self.q_data_type,
is_prefill=True,
has_sinks=self.has_sinks,
has_spec=uses_spec_reorder,
)
decode_use_trtllm = self.use_trtllm_decode_attention
decode_use_trtllm = (
self.use_trtllm_decode_attention and self.dcp_world_size <= 1
)
if not (prefill_use_trtllm and decode_use_trtllm):
if self.has_sinks:
@ -651,7 +812,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
use_cascade=use_cascade,
)
qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[: 1 + num_reqs]
paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs]
@ -703,24 +863,52 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
attn_metadata.max_q_len_prefill = int(query_lens_prefill.max().item())
if not attn_metadata.prefill_use_trtllm:
attn_metadata.prefill_wrapper.plan(
qo_indptr_cpu,
paged_kv_indptr_cpu,
paged_kv_indices,
paged_kv_last_page_len_cpu[prefill_start:],
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
causal=True,
sm_scale=self.sm_scale,
window_left=self.window_left,
logits_soft_cap=self.logits_soft_cap,
q_data_type=self.q_data_type,
kv_data_type=self.kv_cache_dtype,
fixed_split_size=self.prefill_fixed_split_size,
disable_split_kv=self.disable_split_kv,
)
if self.dcp_world_size > 1:
assert isinstance(
attn_metadata.prefill_wrapper, BatchDCPPrefillWrapper
)
attn_metadata.prefill_wrapper.plan(
qo_indptr_cpu=qo_indptr_cpu,
paged_kv_indptr_cpu=paged_kv_indptr_cpu,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu,
prefill_start=prefill_start,
page_size=self.page_size,
num_qo_heads=self.num_qo_heads,
dcp_world_size=self.dcp_world_size,
num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
sm_scale=self.sm_scale,
window_left=self.window_left,
logits_soft_cap=self.logits_soft_cap,
q_data_type=self.q_data_type,
kv_cache_dtype=self.kv_cache_dtype,
prefill_fixed_split_size=self.prefill_fixed_split_size,
disable_split_kv=self.disable_split_kv,
)
else:
assert isinstance(
attn_metadata.prefill_wrapper,
BatchPrefillWithPagedKVCacheWrapper,
)
attn_metadata.prefill_wrapper.plan(
qo_indptr_cpu,
paged_kv_indptr_cpu,
paged_kv_indices,
paged_kv_last_page_len_cpu[prefill_start:],
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
causal=True,
sm_scale=self.sm_scale,
window_left=self.window_left,
logits_soft_cap=self.logits_soft_cap,
q_data_type=self.q_data_type,
kv_data_type=self.kv_cache_dtype,
fixed_split_size=self.prefill_fixed_split_size,
disable_split_kv=self.disable_split_kv,
)
else:
attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(
self.device, non_blocking=True
@ -770,7 +958,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_indices,
self.paged_kv_last_page_len_cpu[:num_input_tokens],
seq_lens_cpu[:num_input_tokens],
self.num_qo_heads,
self.num_qo_heads * self.dcp_world_size,
self.num_kv_heads,
self.head_dim,
self.page_size,
@ -797,6 +985,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
class FlashInferImpl(AttentionImpl):
can_return_lse_for_decode: bool = True
def __init__(
self,
num_heads: int,
@ -989,6 +1179,8 @@ class FlashInferImpl(AttentionImpl):
# Inputs and outputs may be padded for CUDA graphs
query = query[:num_actual_tokens]
key = key[:num_actual_tokens]
value = value[:num_actual_tokens]
output_padded = output
output = output[:num_actual_tokens]
@ -1015,17 +1207,46 @@ class FlashInferImpl(AttentionImpl):
assert prefill_wrapper is not None
if not attn_metadata.prefill_use_trtllm:
assert prefill_wrapper._causal
assert prefill_wrapper._window_left == self.window_left
assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0)
assert prefill_wrapper._sm_scale == self.scale
prefill_wrapper.run(
prefill_query,
kv_cache_permute,
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
out=output[num_decode_tokens:],
)
if self.dcp_world_size > 1:
assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper)
assert prefill_wrapper._context._window_left == self.window_left
assert prefill_wrapper._context._logits_soft_cap == (
self.logits_soft_cap or 0.0
)
assert prefill_wrapper._context._sm_scale == self.scale
assert not prefill_wrapper._context._causal
assert prefill_wrapper._new_tokens._window_left == self.window_left
assert prefill_wrapper._new_tokens._logits_soft_cap == (
self.logits_soft_cap or 0.0
)
assert prefill_wrapper._new_tokens._sm_scale == self.scale
assert prefill_wrapper._new_tokens._causal
prefill_wrapper.run(
layer,
prefill_query,
kv_cache_permute,
key[num_decode_tokens:],
value[num_decode_tokens:],
out=output[num_decode_tokens:],
)
else:
assert isinstance(
prefill_wrapper, BatchPrefillWithPagedKVCacheWrapper
)
assert prefill_wrapper._window_left == self.window_left
assert prefill_wrapper._logits_soft_cap == (
self.logits_soft_cap or 0.0
)
assert prefill_wrapper._sm_scale == self.scale
assert prefill_wrapper._causal
prefill_wrapper.run(
prefill_query,
kv_cache_permute,
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
out=output[num_decode_tokens:],
)
else:
# prefill_query may be non-contiguous
prefill_query = prefill_query.contiguous()
@ -1101,13 +1322,37 @@ class FlashInferImpl(AttentionImpl):
assert decode_wrapper._window_left == self.window_left
assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0)
assert decode_wrapper._sm_scale == self.scale
decode_wrapper.run(
decode_query,
kv_cache_permute,
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
out=output[:num_decode_tokens],
)
if self.dcp_world_size > 1:
decode_query = get_dcp_group().all_gather(
decode_query.contiguous(), dim=-2
)
output_tmp = torch.empty_like(decode_query)
lse = torch.empty(
(decode_query.size(0), decode_query.size(1)),
dtype=torch.float32,
device=decode_query.device,
)
decode_wrapper.run(
decode_query,
kv_cache_permute,
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
out=output_tmp,
lse=lse,
return_lse=True,
)
output[:num_decode_tokens] = cp_lse_ag_out_rs(
output_tmp, lse, get_dcp_group()
)
else:
decode_wrapper.run(
decode_query,
kv_cache_permute,
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
out=output[:num_decode_tokens],
)
else:
# decode_query may be non-contiguous
decode_query = decode_query.contiguous()
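
A small, hedged shape sketch of the DCP decode branch above (single process, torch.cat standing in for the DCP all-gather): each rank attends with the query heads of every DCP rank against its local KV shard, which is also why the planner sizes the decode wrapper with num_qo_heads * dcp_world_size.

import torch

dcp_world_size, local_qo_heads, head_dim, num_decode_tokens = 2, 4, 64, 3
decode_query = torch.randn(num_decode_tokens, local_qo_heads, head_dim)

# stand-in for get_dcp_group().all_gather(decode_query, dim=-2) in the code above
gathered = torch.cat([decode_query] * dcp_world_size, dim=-2)
assert gathered.shape == (num_decode_tokens, local_qo_heads * dcp_world_size, head_dim)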

View File

@ -31,6 +31,7 @@ from vllm.distributed import destroy_distributed_environment, destroy_model_para
from vllm.distributed.device_communicators.shm_broadcast import Handle, MessageQueue
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.distributed.parallel_state import (
get_dcp_group,
get_dp_group,
get_ep_group,
get_pp_group,
@ -726,6 +727,8 @@ class WorkerProc:
pp_rank = get_pp_group().rank_in_group
tp_size = get_tp_group().world_size
tp_rank = get_tp_group().rank_in_group
dcp_size = get_dcp_group().world_size
dcp_rank = get_dcp_group().rank_in_group
process_name = "Worker"
if dp_size > 1:
process_name += f"_DP{dp_rank}"
@ -733,6 +736,8 @@ class WorkerProc:
process_name += f"_PP{pp_rank}"
if tp_size > 1:
process_name += f"_TP{tp_rank}"
if dcp_size > 1:
process_name += f"_DCP{dcp_rank}"
if enable_ep:
ep_rank = get_ep_group().rank_in_group
process_name += f"_EP{ep_rank}"