Fix correctness issue with full CUDA graphs + attention splitting

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-06-24 22:47:42 +00:00
parent 96c0c4ea66
commit 97dbafaad6

View File

@ -485,14 +485,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
device = self.runner.device
block_table = self.block_table
block_table_tensor = block_table.get_device_tensor()[req_slice]
slot_mapping = block_table.slot_mapping_cpu[token_slice].to(
device, non_blocking=True).long()
# block_table_tensor = block_table.get_device_tensor()[:num_reqs]
# block_table.slot_mapping[:num_actual_tokens].copy_(
# block_table.slot_mapping_cpu[:num_actual_tokens],
# non_blocking=True)
# block_table.slot_mapping[num_actual_tokens:].fill_(-1)
# slot_mapping = block_table.slot_mapping[:num_actual_tokens]
block_table.slot_mapping[token_slice].copy_(
block_table.slot_mapping_cpu[token_slice],
non_blocking=True)
block_table.slot_mapping[token_slice.stop:].fill_(-1)
slot_mapping = block_table.slot_mapping[token_slice]
query_start_loc = slice_query_start_locs(
common_attn_metadata.query_start_loc, req_slice)