diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index fa9e94ad189ac..d82afa5b630fd 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -485,14 +485,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): device = self.runner.device block_table = self.block_table block_table_tensor = block_table.get_device_tensor()[req_slice] - slot_mapping = block_table.slot_mapping_cpu[token_slice].to( - device, non_blocking=True).long() - # block_table_tensor = block_table.get_device_tensor()[:num_reqs] - # block_table.slot_mapping[:num_actual_tokens].copy_( - # block_table.slot_mapping_cpu[:num_actual_tokens], - # non_blocking=True) - # block_table.slot_mapping[num_actual_tokens:].fill_(-1) - # slot_mapping = block_table.slot_mapping[:num_actual_tokens] + block_table.slot_mapping[token_slice].copy_( + block_table.slot_mapping_cpu[token_slice], + non_blocking=True) + block_table.slot_mapping[token_slice.stop:].fill_(-1) + slot_mapping = block_table.slot_mapping[token_slice] query_start_loc = slice_query_start_locs( common_attn_metadata.query_start_loc, req_slice)