mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 07:54:58 +08:00
[BugFix][FlashInfer] Fix potential race condition for paged_kv_indptr_cpu (#23737)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
27e88cee74
commit
7ffbf27239
@ -237,6 +237,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
device="cpu",
|
device="cpu",
|
||||||
pin_memory=pin_memory)
|
pin_memory=pin_memory)
|
||||||
self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy()
|
self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy()
|
||||||
|
self.paged_kv_indptr_buffer = torch.zeros_like(
|
||||||
|
self.paged_kv_indptr_cpu, pin_memory=pin_memory)
|
||||||
self.paged_kv_indices_cpu = torch.zeros(max_num_pages,
|
self.paged_kv_indices_cpu = torch.zeros(max_num_pages,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
@ -361,12 +363,18 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
dtype=np.int32,
|
dtype=np.int32,
|
||||||
out=self.paged_kv_indptr_np[1:num_reqs + 1],
|
out=self.paged_kv_indptr_np[1:num_reqs + 1],
|
||||||
)
|
)
|
||||||
|
# NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified
|
||||||
|
# after this line (e.g., for cuda graphs), we need to copy the data to
|
||||||
|
# self.paged_kv_indptr_buffer to avoid race condition.
|
||||||
|
self.paged_kv_indptr_buffer[:num_reqs +
|
||||||
|
1] = (self.paged_kv_indptr_cpu[:num_reqs +
|
||||||
|
1])
|
||||||
paged_kv_indptr = self.paged_kv_indptr[:num_reqs + 1]
|
paged_kv_indptr = self.paged_kv_indptr[:num_reqs + 1]
|
||||||
paged_kv_indptr.copy_(self.paged_kv_indptr_cpu[:num_reqs + 1],
|
paged_kv_indptr.copy_(self.paged_kv_indptr_buffer[:num_reqs + 1],
|
||||||
non_blocking=True)
|
non_blocking=True)
|
||||||
|
|
||||||
# write self.paged_kv_indices inplace
|
# write self.paged_kv_indices inplace
|
||||||
num_actual_pages = num_blocks_np.sum().item()
|
num_actual_pages = self.paged_kv_indptr_np[num_reqs]
|
||||||
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
|
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
|
||||||
_copy_page_indices_kernel[(num_reqs, )](
|
_copy_page_indices_kernel[(num_reqs, )](
|
||||||
paged_kv_indices,
|
paged_kv_indices,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user