diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index f9973a89c7f2c..77bc1eac16806 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -776,7 +776,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): query_start_loc = common_attn_metadata.query_start_loc query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu seq_lens = common_attn_metadata.seq_lens + seq_lens_cpu = seq_lens.cpu() dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens + dcp_local_seq_lens_cpu = ( + dcp_local_seq_lens.cpu() if dcp_local_seq_lens is not None else None + ) num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( @@ -997,6 +1001,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): self.dcp_local_block_size, ) if dcp_local_seq_lens is None: + assert dcp_local_seq_lens_cpu is not None dcp_local_seq_lens = dcp_local_seq_lens_cpu.to( seq_lens.device, non_blocking=True )