[BugFix][FlashInfer] Fix potential race condition for paged_kv_indptr_cpu (#23737)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-08-28 14:22:46 -07:00 committed by GitHub
parent 27e88cee74
commit 7ffbf27239
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -237,6 +237,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
device="cpu",
pin_memory=pin_memory)
self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy()
+self.paged_kv_indptr_buffer = torch.zeros_like(
+self.paged_kv_indptr_cpu, pin_memory=pin_memory)
self.paged_kv_indices_cpu = torch.zeros(max_num_pages,
dtype=torch.int32,
device="cpu",
@@ -361,12 +363,18 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype=np.int32,
out=self.paged_kv_indptr_np[1:num_reqs + 1],
)
+# NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified
+# after this line (e.g., for cuda graphs), we need to copy the data to
+# self.paged_kv_indptr_buffer to avoid race condition.
+self.paged_kv_indptr_buffer[:num_reqs +
+1] = (self.paged_kv_indptr_cpu[:num_reqs +
+1])
paged_kv_indptr = self.paged_kv_indptr[:num_reqs + 1]
-paged_kv_indptr.copy_(self.paged_kv_indptr_cpu[:num_reqs + 1],
+paged_kv_indptr.copy_(self.paged_kv_indptr_buffer[:num_reqs + 1],
non_blocking=True)
# write self.paged_kv_indices inplace
-num_actual_pages = num_blocks_np.sum().item()
+num_actual_pages = self.paged_kv_indptr_np[num_reqs]
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
_copy_page_indices_kernel[(num_reqs, )](
paged_kv_indices,