From 954f7305a106058815bd7e47f5b9d585d8764c05 Mon Sep 17 00:00:00 2001
From: Lily Liu
Date: Thu, 1 Aug 2024 18:44:16 -0700
Subject: [PATCH] [Kernel] Fix input for flashinfer prefill wrapper. (#7008)

---
 vllm/attention/backends/flashinfer.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index ccf8ab03a621..91abaab78dcb 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -133,13 +133,20 @@ class FlashInferMetadata(AttentionMetadata):
                 return
 
             assert self.prefill_wrapper is not None
+            assert self.query_start_loc is not None
             assert self.paged_kv_indices is not None
             assert self.paged_kv_indptr is not None
             assert self.paged_kv_last_page_len is not None
-            self.paged_kv_indices = self.paged_kv_indices.to(self.device)
-            self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
+            batch_size = self.query_start_loc.shape[0] - 1
+            assert batch_size >= 0
+            # The prefill stage does not read kv cache.
+            # Both paged_kv_indices and paged_kv_last_page_len are empty.
+            # paged_kv_indptr is a zero tensor with size batch_size + 1.
+            self.paged_kv_indptr = torch.zeros(batch_size + 1,
+                                               device=self.device)
             self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
                 self.device)
+            self.paged_kv_indices = self.paged_kv_indices.to(self.device)
             self.prefill_wrapper.end_forward()
             self.prefill_wrapper.begin_forward(
                 self.query_start_loc, self.paged_kv_indptr,
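
Below is a rough, standalone sketch of the inputs this patch prepares for the prefill wrapper: because the prefill stage does not read the paged KV cache, the page-table tensors are empty and paged_kv_indptr is all zeros with batch_size + 1 entries, where batch_size is derived from query_start_loc. The variable names mirror the FlashInferMetadata fields, but the concrete values, dtypes, and the CPU device are illustrative assumptions only, not taken from vLLM or FlashInfer.

import torch

# Illustrative device; vLLM would use the metadata's CUDA device.
device = "cpu"

# query_start_loc: cumulative query lengths with a leading zero, one entry per
# sequence plus one. Example (made up): two prefill sequences of lengths 4 and 3.
query_start_loc = torch.tensor([0, 4, 7], dtype=torch.int32, device=device)

batch_size = query_start_loc.shape[0] - 1
assert batch_size >= 0

# Prefill does not read the paged KV cache, so the page-table inputs are empty
# and paged_kv_indptr is a zero tensor with batch_size + 1 entries.
paged_kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
paged_kv_indices = torch.empty(0, dtype=torch.int32, device=device)
paged_kv_last_page_len = torch.empty(0, dtype=torch.int32, device=device)

print(paged_kv_indptr)  # tensor([0, 0, 0], dtype=torch.int32)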