diff --git a/tests/distributed/test_context_parallel.py b/tests/distributed/test_context_parallel.py
index 3576efca591c..b16fd0d06b14 100644
--- a/tests/distributed/test_context_parallel.py
+++ b/tests/distributed/test_context_parallel.py
@@ -39,6 +39,7 @@ class ParallelSetup(NamedTuple):
 class CPTestOptions(NamedTuple):
     multi_node_only: bool
     load_format: str | None = None
+    attn_backend: str | None = None
 
 
 @dataclass
@@ -58,6 +59,7 @@ class CPTestSettings:
         multi_node_only: bool = False,
         runner: RunnerOption = "auto",
         load_format: str | None = None,
+        attn_backend: str | None = None,
     ):
         parallel_setups = []
         for eager_mode_val in [False]:
@@ -79,7 +81,9 @@ class CPTestSettings:
             distributed_backends=["mp"],
             runner=runner,
             test_options=CPTestOptions(
-                multi_node_only=multi_node_only, load_format=load_format
+                multi_node_only=multi_node_only,
+                load_format=load_format,
+                attn_backend=attn_backend,
             ),
         )
 
@@ -117,7 +121,7 @@ def _compare_cp_with_tp(
         chunked_prefill,
     ) = parallel_setup
 
-    multi_node_only, load_format = test_options
+    multi_node_only, load_format, attn_backend = test_options
 
     model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
     model_info.check_transformers_version(on_fail="skip")
@@ -177,6 +181,13 @@ def _compare_cp_with_tp(
     if hf_overrides:
         common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
 
+    if not attn_backend:
+        cp_env = tp_env = {}
+    else:
+        cp_env = tp_env = {
+            "VLLM_ATTENTION_BACKEND": attn_backend,
+        }
+
     cp_args = [
         *common_args,
         "--tensor-parallel-size",
@@ -205,6 +216,8 @@ def _compare_cp_with_tp(
         model_id,
         cp_args,
         tp_args,
+        cp_env,
+        tp_env,
         method=method,
         max_wait_seconds=720,
     )
diff --git a/vllm/config/model.py b/vllm/config/model.py
index f4ed99689e5b..8ec66b6b3160 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -1183,6 +1183,14 @@ class ModelConfig:
                 f"but got {decode_context_parallel_size}"
             )
 
+        num_q_per_kv = total_num_attention_heads // total_num_kv_heads
+        assert num_q_per_kv % decode_context_parallel_size == 0, (
+            f"Total number of q per kv attn heads ({num_q_per_kv}) "
+            "must be divisible by the DCP world size when enabling "
+            "decode context parallel for GQA "
+            f"({parallel_config.decode_context_parallel_size})."
+        )
+
     def get_sliding_window(self) -> int | None:
         """Get the sliding window size from the HF text config if present."""
         return getattr(self.hf_text_config, "sliding_window", None)
diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py
index 62af39513d65..79e5a4c30259 100644
--- a/vllm/utils/flashinfer.py
+++ b/vllm/utils/flashinfer.py
@@ -259,6 +259,7 @@ def use_trtllm_attention(
     num_kv_heads: int,
     num_tokens: int,
    max_seq_len: int,
+    dcp_world_size: int,
     kv_cache_dtype: str,
     q_dtype: torch.dtype,
     is_prefill: bool,
@@ -272,6 +273,14 @@ def use_trtllm_attention(
     if force_use_trtllm is not None and not force_use_trtllm:
         return False
 
+    # Decode context parallel is not supported
+    if dcp_world_size > 1:
+        logger.warning_once(
+            "TRTLLM attention does not support returning LSE and as a result "
+            "does not support DCP; falling back to FlashInfer"
+        )
+        return False
+
     # The platform is not supported
     if not supports_trtllm_attention():
         if force_use_trtllm:
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 0b650e2e0d33..4da1637d96eb 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -10,6 +10,7 @@ import torch
 from flashinfer import (
     BatchDecodeWithPagedKVCacheWrapper,
     BatchPrefillWithPagedKVCacheWrapper,
+    BatchPrefillWithRaggedKVCacheWrapper,
     MultiLevelCascadeAttentionWrapper,
 )
 from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
@@ -24,8 +25,11 @@ from vllm.attention.backends.abstract import (
     AttentionType,
     MultipleOf,
 )
+from vllm.attention.ops.common import cp_lse_ag_out_rs
+from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.config.cache import CacheDType
+from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
     vllm_is_batch_invariant,
 )
@@ -50,6 +54,7 @@ from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
     KVCacheLayoutType,
+    get_dcp_local_seq_lens,
     get_kv_cache_layout,
     get_per_layer_parameters,
     infer_global_hyperparameters,
@@ -160,6 +165,113 @@ def trtllm_prefill_attn_kvfp8_dequant(
     return mock_kv_cache, mock_block_table
 
 
+class BatchDCPPrefillWrapper:
+    def __init__(
+        self,
+        workspace_buffer: torch.Tensor | None = None,
+    ):
+        self._context = BatchPrefillWithPagedKVCacheWrapper(
+            workspace_buffer, get_kv_cache_layout()
+        )
+        self._new_tokens = BatchPrefillWithRaggedKVCacheWrapper(
+            workspace_buffer, get_kv_cache_layout()
+        )
+
+    def plan(
+        self,
+        qo_indptr_cpu: torch.Tensor,
+        paged_kv_indptr_cpu: torch.Tensor,
+        paged_kv_indices: torch.Tensor,
+        paged_kv_last_page_len_cpu: torch.Tensor,
+        prefill_start: int,
+        page_size: int,
+        num_qo_heads: int,
+        dcp_world_size: int,
+        num_kv_heads: int,
+        head_dim: int,
+        sm_scale: float,
+        window_left: int,
+        logits_soft_cap: float | None,
+        q_data_type: torch.dtype,
+        kv_cache_dtype: torch.dtype,
+        prefill_fixed_split_size: int,
+        disable_split_kv: bool,
+    ):
+        """Plan the prefill operation with the given parameters."""
+        self._context.plan(
+            qo_indptr_cpu,
+            paged_kv_indptr_cpu,
+            paged_kv_indices,
+            paged_kv_last_page_len_cpu[prefill_start:],
+            num_qo_heads * dcp_world_size,
+            num_kv_heads,
+            head_dim,
+            page_size,
+            causal=False,  # This is the context run
+            sm_scale=sm_scale,
+            window_left=window_left,
+            logits_soft_cap=logits_soft_cap,
+            q_data_type=q_data_type,
+            kv_data_type=kv_cache_dtype,
+            fixed_split_size=prefill_fixed_split_size,
+            disable_split_kv=disable_split_kv,
+        )
+        self._new_tokens.plan(
+            qo_indptr=qo_indptr_cpu,
+            kv_indptr=qo_indptr_cpu,
+            num_qo_heads=num_qo_heads,
+            num_kv_heads=num_kv_heads,
+            head_dim_qk=head_dim,
+            head_dim_vo=head_dim,
+            causal=True,  # This is the new-tokens run
+            sm_scale=sm_scale,
+            window_left=window_left,
+            logits_soft_cap=logits_soft_cap,
+            q_data_type=q_data_type,
+        )
+
+    def run(
+        self,
+        layer: torch.nn.Module,
+        prefill_query: torch.Tensor,
+        kv_cache_permute: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        out: torch.Tensor,
+    ):
+        prefill_query_across_dcp = get_dcp_group().all_gather(
+            prefill_query.contiguous(), dim=1
+        )
+        output_context_tmp, lse_context_tmp = self._context.run(
+            prefill_query_across_dcp,
+            kv_cache_permute,
+            k_scale=layer._k_scale_float,
+            v_scale=layer._v_scale_float,
+            return_lse=True,
+        )
+        output_context, lse_context = cp_lse_ag_out_rs(
+            output_context_tmp, lse_context_tmp, get_dcp_group(), return_lse=True
+        )
+        lse_context = lse_context.transpose(0, 1).contiguous()
+
+        output_query, lse_query = self._new_tokens.run(
+            prefill_query,
+            key,
+            value,
+            return_lse=True,
+        )
+        lse_query = lse_query.transpose(0, 1).contiguous()
+
+        merge_attn_states(
+            out,
+            output_context,
+            lse_context,
+            output_query,
+            lse_query,
+        )
+        return out
+
+
 class FlashInferBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
@@ -281,7 +393,9 @@ class FlashInferMetadata:
     # For cascade attention (CPU for planning).
     use_cascade: bool
 
-    prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper | None = None
+    prefill_wrapper: (
+        BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None
+    ) = None
     decode_wrapper: BatchDecodeWithPagedKVCacheWrapper | None = None
     cascade_wrapper: MultiLevelCascadeAttentionWrapper | None = None
 
@@ -303,7 +417,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         self.cache_config = vllm_config.cache_config
         self.model_config = vllm_config.model_config
         self._workspace_buffer = None
-        self._prefill_wrapper = None  # Wrapper for prefill/append
+        self._prefill_wrapper: (
+            BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None
+        ) = None  # Wrapper for prefill/append
        self._decode_wrapper = None  # Wrapper for decode (general shape)
 
         if vllm_is_batch_invariant():
@@ -341,9 +457,23 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 self.compilation_config.max_cudagraph_capture_size,
             )
 
-        self.num_qo_heads = self.model_config.get_num_attention_heads(
-            self.vllm_config.parallel_config
+        try:
+            self.dcp_world_size = get_dcp_group().world_size
+            self.dcp_rank = get_dcp_group().rank_in_group
+            self.dcp_kv_cache_interleave_size = (
+                vllm_config.parallel_config.dcp_kv_cache_interleave_size
+            )
+        except AssertionError:
+            # DCP might not be initialized in testing
+            self.dcp_world_size = 1
+            self.dcp_rank = 0
+            self.dcp_kv_cache_interleave_size = 1
+
+        self.num_qo_heads = (
+            self.model_config.get_num_attention_heads(self.vllm_config.parallel_config)
+            * self.dcp_world_size
         )
+
         self.num_kv_heads = self.kv_cache_spec.num_kv_heads
         self.head_dim = self.kv_cache_spec.head_size
         self.page_size = self.kv_cache_spec.block_size
@@ -455,11 +585,19 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             )
         return self._workspace_buffer
 
-    def _get_prefill_wrapper(self):
+    def _get_prefill_wrapper(
+        self,
+    ) -> BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper:
         if self._prefill_wrapper is None:
-            self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
-                self._get_workspace_buffer(), get_kv_cache_layout()
-            )
+            if self.dcp_world_size > 1:
+                self._prefill_wrapper = BatchDCPPrefillWrapper(
+                    workspace_buffer=self._get_workspace_buffer(),
+                )
+            else:
+                self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
+                    self._get_workspace_buffer(), get_kv_cache_layout()
+                )
+        assert self._prefill_wrapper is not None
         return self._prefill_wrapper
 
     def _get_decode_wrapper(self, batch_size: int, use_cudagraph: bool = False):
@@ -526,9 +664,29 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         max_seq_len = common_attn_metadata.max_seq_len
         seq_lens = common_attn_metadata.seq_lens
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
-        seq_lens_np = seq_lens_cpu.numpy()
         block_table_tensor = common_attn_metadata.block_table_tensor
 
+        qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
+        if self.dcp_world_size > 1:
+            if num_prefills > 0:
+                qo_indptr_prefill_cpu = (
+                    qo_indptr_cpu[num_decodes:] - qo_indptr_cpu[num_decodes]
+                )
+                query_lens_prefill_cpu = (
+                    qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1]
+                )
+                seq_lens_cpu[num_decodes:] = (
+                    seq_lens_cpu[num_decodes:] - query_lens_prefill_cpu
+                )
+
+            seq_lens_cpu = get_dcp_local_seq_lens(
+                seq_lens_cpu,
+                self.dcp_world_size,
+                self.dcp_rank,
+                self.dcp_kv_cache_interleave_size,
+            )
+
+        seq_lens_np = seq_lens_cpu.numpy()
         num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size
 
         use_cascade = common_prefix_len > 0
@@ -589,7 +747,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         # write self.paged_kv_last_page_len_cpu inplace
         paged_kv_last_page_len_np = seq_lens_np % page_size
         self.paged_kv_last_page_len_np[:num_reqs] = np.where(
-            paged_kv_last_page_len_np == 0,
+            (paged_kv_last_page_len_np == 0) & (seq_lens_np != 0),
             page_size,
             paged_kv_last_page_len_np,
         )
@@ -600,13 +758,16 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             self.num_kv_heads,
             num_prefill_tokens,
             max_seq_len,
+            self.dcp_world_size,
             self.cache_dtype,
             self.q_data_type,
             is_prefill=True,
             has_sinks=self.has_sinks,
             has_spec=uses_spec_reorder,
         )
-        decode_use_trtllm = self.use_trtllm_decode_attention
+        decode_use_trtllm = (
+            self.use_trtllm_decode_attention and self.dcp_world_size <= 1
+        )
 
         if not (prefill_use_trtllm and decode_use_trtllm):
             if self.has_sinks:
@@ -651,7 +812,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             use_cascade=use_cascade,
         )
 
-        qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
         paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[: 1 + num_reqs]
         paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs]
 
@@ -703,24 +863,52 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             attn_metadata.max_q_len_prefill = int(query_lens_prefill.max().item())
 
             if not attn_metadata.prefill_use_trtllm:
-                attn_metadata.prefill_wrapper.plan(
-                    qo_indptr_cpu,
-                    paged_kv_indptr_cpu,
-                    paged_kv_indices,
-                    paged_kv_last_page_len_cpu[prefill_start:],
-                    self.num_qo_heads,
-                    self.num_kv_heads,
-                    self.head_dim,
-                    self.page_size,
-                    causal=True,
-                    sm_scale=self.sm_scale,
-                    window_left=self.window_left,
-                    logits_soft_cap=self.logits_soft_cap,
-                    q_data_type=self.q_data_type,
-                    kv_data_type=self.kv_cache_dtype,
-                    fixed_split_size=self.prefill_fixed_split_size,
-                    disable_split_kv=self.disable_split_kv,
-                )
+                if self.dcp_world_size > 1:
+                    assert isinstance(
+                        attn_metadata.prefill_wrapper, BatchDCPPrefillWrapper
+                    )
+                    attn_metadata.prefill_wrapper.plan(
+                        qo_indptr_cpu=qo_indptr_cpu,
+                        paged_kv_indptr_cpu=paged_kv_indptr_cpu,
+                        paged_kv_indices=paged_kv_indices,
+                        paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu,
+                        prefill_start=prefill_start,
+                        page_size=self.page_size,
+                        num_qo_heads=self.num_qo_heads,
+                        dcp_world_size=self.dcp_world_size,
+                        num_kv_heads=self.num_kv_heads,
+                        head_dim=self.head_dim,
+                        sm_scale=self.sm_scale,
+                        window_left=self.window_left,
+                        logits_soft_cap=self.logits_soft_cap,
+                        q_data_type=self.q_data_type,
+                        kv_cache_dtype=self.kv_cache_dtype,
+                        prefill_fixed_split_size=self.prefill_fixed_split_size,
+                        disable_split_kv=self.disable_split_kv,
+                    )
+                else:
+                    assert isinstance(
+                        attn_metadata.prefill_wrapper,
+                        BatchPrefillWithPagedKVCacheWrapper,
+                    )
+                    attn_metadata.prefill_wrapper.plan(
+                        qo_indptr_cpu,
+                        paged_kv_indptr_cpu,
+                        paged_kv_indices,
+                        paged_kv_last_page_len_cpu[prefill_start:],
+                        self.num_qo_heads,
+                        self.num_kv_heads,
+                        self.head_dim,
+                        self.page_size,
+                        causal=True,
+                        sm_scale=self.sm_scale,
+                        window_left=self.window_left,
+                        logits_soft_cap=self.logits_soft_cap,
+                        q_data_type=self.q_data_type,
+                        kv_data_type=self.kv_cache_dtype,
+                        fixed_split_size=self.prefill_fixed_split_size,
+                        disable_split_kv=self.disable_split_kv,
+                    )
             else:
                 attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(
                     self.device, non_blocking=True
                )
@@ -770,7 +958,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 paged_kv_indices,
                 self.paged_kv_last_page_len_cpu[:num_input_tokens],
                 seq_lens_cpu[:num_input_tokens],
-                self.num_qo_heads,
+                self.num_qo_heads * self.dcp_world_size,
                 self.num_kv_heads,
                 self.head_dim,
                 self.page_size,
@@ -797,6 +985,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
 
 
 class FlashInferImpl(AttentionImpl):
+    can_return_lse_for_decode: bool = True
+
     def __init__(
         self,
         num_heads: int,
@@ -989,6 +1179,8 @@ class FlashInferImpl(AttentionImpl):
 
         # Inputs and outputs may be padded for CUDA graphs
         query = query[:num_actual_tokens]
+        key = key[:num_actual_tokens]
+        value = value[:num_actual_tokens]
         output_padded = output
         output = output[:num_actual_tokens]
 
@@ -1015,17 +1207,46 @@ class FlashInferImpl(AttentionImpl):
            assert prefill_wrapper is not None
 
            if not attn_metadata.prefill_use_trtllm:
-                assert prefill_wrapper._causal
-                assert prefill_wrapper._window_left == self.window_left
-                assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0)
-                assert prefill_wrapper._sm_scale == self.scale
-                prefill_wrapper.run(
-                    prefill_query,
-                    kv_cache_permute,
-                    k_scale=layer._k_scale_float,
-                    v_scale=layer._v_scale_float,
-                    out=output[num_decode_tokens:],
-                )
+                if self.dcp_world_size > 1:
+                    assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper)
+                    assert prefill_wrapper._context._window_left == self.window_left
+                    assert prefill_wrapper._context._logits_soft_cap == (
+                        self.logits_soft_cap or 0.0
+                    )
+                    assert prefill_wrapper._context._sm_scale == self.scale
+                    assert not prefill_wrapper._context._causal
+                    assert prefill_wrapper._new_tokens._window_left == self.window_left
+                    assert prefill_wrapper._new_tokens._logits_soft_cap == (
+                        self.logits_soft_cap or 0.0
+                    )
+                    assert prefill_wrapper._new_tokens._sm_scale == self.scale
+                    assert prefill_wrapper._new_tokens._causal
+
+                    prefill_wrapper.run(
+                        layer,
+                        prefill_query,
+                        kv_cache_permute,
+                        key[num_decode_tokens:],
+                        value[num_decode_tokens:],
+                        out=output[num_decode_tokens:],
+                    )
+                else:
+                    assert isinstance(
+                        prefill_wrapper, BatchPrefillWithPagedKVCacheWrapper
+                    )
+                    assert prefill_wrapper._window_left == self.window_left
+                    assert prefill_wrapper._logits_soft_cap == (
+                        self.logits_soft_cap or 0.0
+                    )
+                    assert prefill_wrapper._sm_scale == self.scale
+                    assert prefill_wrapper._causal
+                    prefill_wrapper.run(
+                        prefill_query,
+                        kv_cache_permute,
+                        k_scale=layer._k_scale_float,
+                        v_scale=layer._v_scale_float,
+                        out=output[num_decode_tokens:],
+                    )
            else:
                # prefill_query may be non-contiguous
                prefill_query = prefill_query.contiguous()
@@ -1101,13 +1322,37 @@ class FlashInferImpl(AttentionImpl):
                assert decode_wrapper._window_left == self.window_left
                assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0)
                assert decode_wrapper._sm_scale == self.scale
-                decode_wrapper.run(
-                    decode_query,
-                    kv_cache_permute,
-                    k_scale=layer._k_scale_float,
-                    v_scale=layer._v_scale_float,
-                    out=output[:num_decode_tokens],
-                )
+
+                if self.dcp_world_size > 1:
+                    decode_query = get_dcp_group().all_gather(
+                        decode_query.contiguous(), dim=-2
+                    )
+                    output_tmp = torch.empty_like(decode_query)
+                    lse = torch.empty(
+                        (decode_query.size(0), decode_query.size(1)),
+                        dtype=torch.float32,
+                        device=decode_query.device,
+                    )
+                    decode_wrapper.run(
+                        decode_query,
+                        kv_cache_permute,
+                        k_scale=layer._k_scale_float,
+                        v_scale=layer._v_scale_float,
+                        out=output_tmp,
+                        lse=lse,
+                        return_lse=True,
+                    )
+                    output[:num_decode_tokens] = cp_lse_ag_out_rs(
+                        output_tmp, lse, get_dcp_group()
+                    )
+                else:
+                    decode_wrapper.run(
+                        decode_query,
+                        kv_cache_permute,
+                        k_scale=layer._k_scale_float,
+                        v_scale=layer._v_scale_float,
+                        out=output[:num_decode_tokens],
+                    )
            else:
                # decode_query may be non-contiguous
                decode_query = decode_query.contiguous()
diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py
index 1e249161c688..881e6ef40aaf 100644
--- a/vllm/v1/executor/multiproc_executor.py
+++ b/vllm/v1/executor/multiproc_executor.py
@@ -31,6 +31,7 @@ from vllm.distributed import destroy_distributed_environment, destroy_model_para
 from vllm.distributed.device_communicators.shm_broadcast import Handle, MessageQueue
 from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
 from vllm.distributed.parallel_state import (
+    get_dcp_group,
     get_dp_group,
     get_ep_group,
     get_pp_group,
@@ -726,6 +727,8 @@ class WorkerProc:
        pp_rank = get_pp_group().rank_in_group
        tp_size = get_tp_group().world_size
        tp_rank = get_tp_group().rank_in_group
+        dcp_size = get_dcp_group().world_size
+        dcp_rank = get_dcp_group().rank_in_group
        process_name = "Worker"
        if dp_size > 1:
            process_name += f"_DP{dp_rank}"
@@ -733,6 +736,8 @@ class WorkerProc:
            process_name += f"_PP{pp_rank}"
        if tp_size > 1:
            process_name += f"_TP{tp_rank}"
+        if dcp_size > 1:
+            process_name += f"_DCP{dcp_rank}"
        if enable_ep:
            ep_rank = get_ep_group().rank_in_group
            process_name += f"_EP{ep_rank}"