diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index e9ec96835f277..77bc1eac16806 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -776,7 +776,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): query_start_loc = common_attn_metadata.query_start_loc query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu seq_lens = common_attn_metadata.seq_lens + seq_lens_cpu = seq_lens.cpu() dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens + dcp_local_seq_lens_cpu = ( + dcp_local_seq_lens.cpu() if dcp_local_seq_lens is not None else None + ) num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( @@ -989,6 +993,19 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): dcp_tot_seq_lens_device = None if self.dcp_world_size > 1: dcp_tot_seq_lens_device = seq_lens[:num_decodes] + if dcp_local_seq_lens_cpu is None: + dcp_local_seq_lens_cpu = get_dcp_local_seq_lens( + seq_lens_cpu, + self.dcp_world_size, + self.dcp_rank, + self.dcp_local_block_size, + ) + if dcp_local_seq_lens is None: + assert dcp_local_seq_lens_cpu is not None + dcp_local_seq_lens = dcp_local_seq_lens_cpu.to( + seq_lens.device, non_blocking=True + ) + seq_lens_cpu = dcp_local_seq_lens_cpu seq_lens = dcp_local_seq_lens # After DCP distribution, the maximum number of tokens for any rank is