From 48fc8b1e595766af9c91edfc1de43f3a352575eb Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Wed, 19 Nov 2025 10:04:07 -0500
Subject: [PATCH] [BugFix] Fix async-scheduling + FlashAttn MLA (#28990)

Signed-off-by: Lucas Wilkinson
---
 vllm/v1/attention/backends/mla/common.py        | 15 +++++++++------
 vllm/v1/attention/backends/mla/flashattn_mla.py |  2 +-
 vllm/v1/attention/backends/utils.py             |  1 +
 vllm/v1/worker/gpu_model_runner.py              | 10 +++++++---
 4 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 2ccdd1f143ce..e328049b53c7 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -755,6 +755,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         seq_lens = common_attn_metadata.seq_lens
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
         dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens
+        dcp_local_seq_lens_cpu = common_attn_metadata.dcp_local_seq_lens_cpu
 
         query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
 
@@ -944,18 +945,20 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
 
         decode_metadata = None
         if num_decodes > 0:
+            dcp_tot_seq_lens_device = None
+            if self.dcp_world_size > 1:
+                dcp_tot_seq_lens_device = seq_lens[:num_decodes]
+                seq_lens_cpu = dcp_local_seq_lens_cpu
+                seq_lens = dcp_local_seq_lens
+
             decode_metadata = self._build_decode(
                 block_table_tensor=block_table_tensor[:num_decodes, ...],
                 seq_lens_cpu=seq_lens_cpu[:num_decodes],
-                seq_lens_device=dcp_local_seq_lens[:num_decodes]
-                if self.dcp_world_size > 1 and dcp_local_seq_lens is not None
-                else seq_lens[:num_decodes],
+                seq_lens_device=seq_lens[:num_decodes],
                 query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1],
                 query_start_loc_device=query_start_loc[: num_decodes + 1],
                 num_decode_tokens=num_decode_tokens,
-                dcp_tot_seq_lens_device=seq_lens[:num_decodes]
-                if self.dcp_world_size > 1
-                else None,
+                dcp_tot_seq_lens_device=dcp_tot_seq_lens_device,
             )
 
         attn_metadata = self.metadata_cls(
diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py
index 7794e89cc0a9..12639edc8b9a 100644
--- a/vllm/v1/attention/backends/mla/flashattn_mla.py
+++ b/vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -173,7 +173,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
     ) -> FlashAttnMLADecodeMetadata:
         query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
         max_query_len = query_lens_cpu.max().item()
-        max_seq_len = seq_lens_device.max().item()
+        max_seq_len = seq_lens_cpu.max().item()
 
         # For Flash Attention MLA + full cudagraph
         max_num_splits = 0
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 578153cda786..0dd189633129 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -92,6 +92,7 @@ class CommonAttentionMetadata:
     encoder_seq_lens: np.ndarray | None = None
 
     dcp_local_seq_lens: torch.Tensor | None = None
+    dcp_local_seq_lens_cpu: torch.Tensor | None = None
     """Sequence lengths of the local rank in decode context parallelism world"""
 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 506118d2d762..3b00085b6bb9 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1451,9 +1451,12 @@ class GPUModelRunner(
             num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
                 :num_reqs
             ]
-        dcp_local_seq_lens = (
-            self.dcp_local_seq_lens.gpu[:num_reqs] if self.dcp_world_size > 1 else None
-        )
+
+        dcp_local_seq_lens, dcp_local_seq_lens_cpu = None, None
+        if self.dcp_world_size > 1:
+            dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs]
+            dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[:num_reqs]
+
         spec_decode_common_attn_metadata = None
 
         if for_cudagraph_capture:
@@ -1521,6 +1524,7 @@ class GPUModelRunner(
             causal=True,
             encoder_seq_lens=encoder_seq_lens,
             dcp_local_seq_lens=dcp_local_seq_lens,
+            dcp_local_seq_lens_cpu=dcp_local_seq_lens_cpu,
        )
 
         if self.speculative_config and spec_decode_common_attn_metadata is None:
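Illustration (not part of the patch): the flashattn_mla.py change switches max_seq_len from the device tensor to the host-side copy. Calling .max().item() on a CUDA tensor blocks the host until the value is computed and copied back, and my reading of the fix is that under async scheduling the builder should not depend on device-side sequence lengths at metadata-build time at all, which is why the model runner now also threads a CPU copy (dcp_local_seq_lens_cpu) through CommonAttentionMetadata. The sketch below is illustrative only; the helper name and tensors are made up, not vLLM API.

import torch

def max_seq_len_for_decode(seq_lens_cpu: torch.Tensor) -> int:
    # Hypothetical helper: derive the max decode sequence length from the
    # host-side tensor so metadata building never has to synchronize with
    # the GPU. The same .max().item() on a CUDA tensor would block the host
    # until the device has produced the value.
    return int(seq_lens_cpu.max().item())

if __name__ == "__main__":
    seq_lens_cpu = torch.tensor([17, 512, 33], dtype=torch.int32)
    # Device copy kept only for the attention kernel; scalar bookkeeping such
    # as max_seq_len comes from the CPU tensor.
    seq_lens_gpu = seq_lens_cpu.to("cuda") if torch.cuda.is_available() else seq_lens_cpu
    print(max_seq_len_for_decode(seq_lens_cpu))  # -> 512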
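Illustration (not part of the patch): the reworked decode branch in common.py uses a capture-before-rebind pattern. The total (across-DCP-ranks) sequence lengths are snapshotted into dcp_tot_seq_lens_device first, and only then are seq_lens / seq_lens_cpu rebound to the DCP-local tensors, so the CPU and device views handed to _build_decode describe the same (local) lengths. A standalone toy version with made-up CPU tensors and an assumed DCP world size of 2:

import torch

dcp_world_size = 2  # pretend decode context parallelism across 2 ranks
num_decodes = 2

# Made-up data: totals across ranks vs. this rank's local share.
seq_lens = torch.tensor([100, 64])
seq_lens_cpu = seq_lens.clone()
dcp_local_seq_lens = torch.tensor([50, 32])
dcp_local_seq_lens_cpu = dcp_local_seq_lens.clone()

dcp_tot_seq_lens = None
if dcp_world_size > 1:
    # Snapshot the totals *before* rebinding the names to the local tensors,
    # mirroring the order used in the patch.
    dcp_tot_seq_lens = seq_lens[:num_decodes]
    seq_lens_cpu = dcp_local_seq_lens_cpu
    seq_lens = dcp_local_seq_lens

# CPU and "device" views now agree (both are local), totals stay available.
assert int(seq_lens_cpu.max()) == int(seq_lens.max()) == 50
assert dcp_tot_seq_lens is not None and int(dcp_tot_seq_lens.max()) == 100
print(seq_lens_cpu.tolist(), dcp_tot_seq_lens.tolist())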