diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index e9ec96835f277..f9973a89c7f2c 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -989,6 +989,18 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         dcp_tot_seq_lens_device = None
         if self.dcp_world_size > 1:
             dcp_tot_seq_lens_device = seq_lens[:num_decodes]
+            if dcp_local_seq_lens_cpu is None:
+                dcp_local_seq_lens_cpu = get_dcp_local_seq_lens(
+                    seq_lens_cpu,
+                    self.dcp_world_size,
+                    self.dcp_rank,
+                    self.dcp_local_block_size,
+                )
+            if dcp_local_seq_lens is None:
+                dcp_local_seq_lens = dcp_local_seq_lens_cpu.to(
+                    seq_lens.device, non_blocking=True
+                )
+            seq_lens_cpu = dcp_local_seq_lens_cpu
             seq_lens = dcp_local_seq_lens
         # After DCP distribution, the maximum number of tokens for any rank is