From ac80c90235098d3bacec3364990f7daa8a27da04 Mon Sep 17 00:00:00 2001
From: Sachin Singh
Date: Mon, 22 Dec 2025 18:33:43 +0530
Subject: [PATCH 1/2] [Fix] Fix MLA attention crash when using DP with DCP

Signed-off-by: Sachin Singh
---
 vllm/v1/attention/backends/mla/common.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index e9ec96835f277..f9973a89c7f2c 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -989,6 +989,18 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         dcp_tot_seq_lens_device = None
         if self.dcp_world_size > 1:
             dcp_tot_seq_lens_device = seq_lens[:num_decodes]
+            if dcp_local_seq_lens_cpu is None:
+                dcp_local_seq_lens_cpu = get_dcp_local_seq_lens(
+                    seq_lens_cpu,
+                    self.dcp_world_size,
+                    self.dcp_rank,
+                    self.dcp_local_block_size,
+                )
+            if dcp_local_seq_lens is None:
+                dcp_local_seq_lens = dcp_local_seq_lens_cpu.to(
+                    seq_lens.device, non_blocking=True
+                )
+            seq_lens_cpu = dcp_local_seq_lens_cpu
             seq_lens = dcp_local_seq_lens
 
         # After DCP distribution, the maximum number of tokens for any rank is

From aa416c04b66bc0be062ee6e8e8682167a611f693 Mon Sep 17 00:00:00 2001
From: Sachin Singh
Date: Tue, 23 Dec 2025 16:49:36 +0530
Subject: [PATCH 2/2] Fix linting and address deprecation warnings

Signed-off-by: Sachin Singh
---
 vllm/v1/attention/backends/mla/common.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index f9973a89c7f2c..77bc1eac16806 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -776,7 +776,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         query_start_loc = common_attn_metadata.query_start_loc
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
         seq_lens = common_attn_metadata.seq_lens
+        seq_lens_cpu = seq_lens.cpu()
         dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens
+        dcp_local_seq_lens_cpu = (
+            dcp_local_seq_lens.cpu() if dcp_local_seq_lens is not None else None
+        )
 
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
             split_decodes_and_prefills(
@@ -997,6 +1001,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                     self.dcp_local_block_size,
                 )
                 if dcp_local_seq_lens is None:
+                    assert dcp_local_seq_lens_cpu is not None
                     dcp_local_seq_lens = dcp_local_seq_lens_cpu.to(
                         seq_lens.device, non_blocking=True
                     )