diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 70d3471a47259..5fc3a1517b690 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -237,6 +237,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                                                device="cpu",
                                                pin_memory=pin_memory)
         self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy()
+        self.paged_kv_indptr_buffer = torch.zeros_like(
+            self.paged_kv_indptr_cpu, pin_memory=pin_memory)
         self.paged_kv_indices_cpu = torch.zeros(max_num_pages,
                                                 dtype=torch.int32,
                                                 device="cpu",
@@ -361,12 +363,18 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             dtype=np.int32,
             out=self.paged_kv_indptr_np[1:num_reqs + 1],
         )
+        # NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified
+        # after this line (e.g., for cuda graphs), we need to copy the data to
+        # self.paged_kv_indptr_buffer to avoid race condition.
+        self.paged_kv_indptr_buffer[:num_reqs +
+                                    1] = (self.paged_kv_indptr_cpu[:num_reqs +
+                                                                   1])
         paged_kv_indptr = self.paged_kv_indptr[:num_reqs + 1]
-        paged_kv_indptr.copy_(self.paged_kv_indptr_cpu[:num_reqs + 1],
+        paged_kv_indptr.copy_(self.paged_kv_indptr_buffer[:num_reqs + 1],
                               non_blocking=True)
 
         # write self.paged_kv_indices inplace
-        num_actual_pages = num_blocks_np.sum().item()
+        num_actual_pages = self.paged_kv_indptr_np[num_reqs]
         paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
         _copy_page_indices_kernel[(num_reqs, )](
             paged_kv_indices,
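For reviewers, here is a minimal standalone sketch of the pattern the diff introduces: a `non_blocking=True` copy from pinned host memory reads its source asynchronously, so if `paged_kv_indptr_cpu` is overwritten while the transfer is still in flight (e.g. when the same buffer is re-padded for CUDA-graph capture), the GPU can receive corrupted data; staging into a dedicated pinned buffer first removes that hazard. It also shows why the total page count can be read from the last cumsum entry instead of a separate `num_blocks_np.sum().item()`. The names (`indptr_cpu`, `indptr_staging`, `indptr_gpu`) and the toy `num_blocks` data are illustrative only, not vLLM APIs, and the sketch assumes a CUDA device is available.

```python
import numpy as np
import torch

# Illustrative sizes; these stand in for max_num_reqs / num_reqs in the diff.
max_num_reqs, num_reqs = 8, 4

# Reusable pinned host tensor. Other code paths (e.g. CUDA-graph capture)
# may overwrite it as soon as metadata building returns.
indptr_cpu = torch.zeros(max_num_reqs + 1, dtype=torch.int32, pin_memory=True)
indptr_np = indptr_cpu.numpy()  # shares memory with indptr_cpu

# Dedicated pinned staging buffer: the async H2D copy reads only from here,
# so later writes to indptr_cpu cannot corrupt the in-flight transfer.
indptr_staging = torch.zeros_like(indptr_cpu, pin_memory=True)

indptr_gpu = torch.zeros(max_num_reqs + 1, dtype=torch.int32, device="cuda")

# Build the indptr on the host: cumulative page counts per request.
num_blocks = np.array([2, 3, 1, 3], dtype=np.int32)  # toy data
np.cumsum(num_blocks, dtype=np.int32, out=indptr_np[1:num_reqs + 1])

# The last filled indptr entry already holds the total page count, so no
# separate num_blocks.sum().item() reduction is needed.
num_actual_pages = int(indptr_np[num_reqs])  # 2 + 3 + 1 + 3 == 9

# Stage first, then launch the non-blocking copy from the staging buffer.
indptr_staging[:num_reqs + 1] = indptr_cpu[:num_reqs + 1]
indptr_gpu[:num_reqs + 1].copy_(indptr_staging[:num_reqs + 1],
                                non_blocking=True)

# indptr_cpu may now be overwritten immediately (e.g. re-padded for a CUDA
# graph); the pending copy no longer reads from it.
indptr_cpu.zero_()
```

The `num_actual_pages` change is a minor simplification on top of the race fix: the cumulative sum already produces the total in its last entry, so the extra reduction over `num_blocks_np` is redundant.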