diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index f948157c2b575..babb4c934420f 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -39,6 +39,7 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport,
                                               split_decodes_and_prefills)
 # yapf: enable
 from vllm.v1.kv_cache_interface import AttentionSpec
+from vllm.v1.worker.utils import CpuGpuBuffer
 
 FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
 
@@ -215,34 +216,20 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         self.global_hyperparameters = infer_global_hyperparameters(
             get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl))
 
-        # Preparing persistent buffers (device-side)
-        self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
-                                           dtype=torch.int32,
-                                           device=self.device)
-        self.paged_kv_indices = torch.zeros(
-            max_num_pages,  # max num pages possible
-            dtype=torch.int32,
-            device=self.device)
-        self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
-                                                  dtype=torch.int32,
-                                                  device=self.device)
-        # host-side buffer
+        # Preparing persistent buffers
         pin_memory = is_pin_memory_available()
-        self.paged_kv_indptr_cpu = torch.zeros(max_num_reqs + 1,
-                                               dtype=torch.int32,
-                                               device="cpu",
-                                               pin_memory=pin_memory)
-        self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy()
-        self.paged_kv_indices_cpu = torch.zeros(max_num_pages,
-                                                dtype=torch.int32,
-                                                device="cpu",
-                                                pin_memory=pin_memory)
-        self.paged_kv_last_page_len_cpu = torch.zeros(max_num_reqs,
-                                                      dtype=torch.int32,
-                                                      device="cpu",
-                                                      pin_memory=pin_memory)
-        self.paged_kv_last_page_len_np = (
-            self.paged_kv_last_page_len_cpu.numpy())
+        self.paged_kv_indptr = CpuGpuBuffer(max_num_reqs + 1,
+                                            dtype=torch.int32,
+                                            device=self.device,
+                                            pin_memory=pin_memory)
+        self.paged_kv_indices = CpuGpuBuffer(max_num_pages,
+                                             dtype=torch.int32,
+                                             device=self.device,
+                                             pin_memory=pin_memory)
+        self.paged_kv_last_page_len = CpuGpuBuffer(max_num_reqs,
+                                                   dtype=torch.int32,
+                                                   device=self.device,
+                                                   pin_memory=pin_memory)
 
     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
@@ -269,10 +256,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
 
         if decode_wrapper is None:
             if use_cudagraph:
-                paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1]
-                paged_kv_indices = self.paged_kv_indices
-                paged_kv_last_page_len = self.paged_kv_last_page_len[:
-                                                                     batch_size]
+                paged_kv_indptr = self.paged_kv_indptr.gpu[:batch_size + 1]
+                paged_kv_indices = self.paged_kv_indices.gpu
+                paged_kv_last_page_len = (
+                    self.paged_kv_last_page_len.gpu[:batch_size])
             else:
                 paged_kv_indptr = None
                 paged_kv_indices = None
@@ -355,15 +342,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         np.cumsum(
             num_blocks_np,
             dtype=np.int32,
-            out=self.paged_kv_indptr_np[1:num_reqs + 1],
+            out=self.paged_kv_indptr.np[1:num_reqs + 1],
         )
-        paged_kv_indptr = self.paged_kv_indptr[:num_reqs + 1]
-        paged_kv_indptr.copy_(self.paged_kv_indptr_cpu[:num_reqs + 1],
-                              non_blocking=True)
+        paged_kv_indptr = self.paged_kv_indptr.copy_to_gpu(num_reqs + 1)
 
         # write self.paged_kv_indices inplace
         num_actual_pages = num_blocks_np.sum().item()
-        paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
+        paged_kv_indices = self.paged_kv_indices.gpu[:num_actual_pages]
         _copy_page_indices_kernel[(num_reqs, )](
             paged_kv_indices,
             block_table_tensor,
@@ -374,7 +359,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
 
         # write self.paged_kv_last_page_len_cpu inplace
         paged_kv_last_page_len_np = seq_lens_np % page_size
-        self.paged_kv_last_page_len_np[:num_reqs] = np.where(
+        self.paged_kv_last_page_len.np[:num_reqs] = np.where(
             paged_kv_last_page_len_np == 0,
             page_size,
             paged_kv_last_page_len_np,
@@ -418,8 +403,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         )
 
         qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
-        paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[:1 + num_reqs]
-        paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs]
+        paged_kv_indptr_cpu = self.paged_kv_indptr.cpu[:1 + num_reqs]
+        paged_kv_last_page_len_cpu = self.paged_kv_last_page_len.cpu[:num_reqs]
 
         if attn_metadata.use_cascade:
             attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
@@ -495,14 +480,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 # Carefully fulfill the padding region with reasonable value
                 # on cpu.
                 # Make sure paged_kv_indptr_cpu is not decreasing
-                self.paged_kv_indptr_cpu[1 + num_decodes:1 +
-                                         num_input_tokens].fill_(
-                                             paged_kv_indptr_cpu[-1])
+                self.paged_kv_indptr.np[1 + num_decodes:1 +
+                                        num_input_tokens].fill(
+                                            paged_kv_indptr_cpu[-1])
                 # Fill the remaining paged_kv_last_page_len_cpu with 1.
                 # This is because flashinfer treats 0 as a full page
                 # instead of empty.
-                self.paged_kv_last_page_len_cpu[
-                    num_decodes:num_input_tokens].fill_(1)
+                self.paged_kv_last_page_len.np[
+                    num_decodes:num_input_tokens].fill(1)
             else:
                 num_input_tokens = num_decodes
 
@@ -515,9 +500,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             # in atten_metadata when using cudagraph.
            fast_plan_decode(
                 attn_metadata.decode_wrapper,
-                self.paged_kv_indptr_cpu[:num_input_tokens + 1],
+                self.paged_kv_indptr.cpu[:num_input_tokens + 1],
                 paged_kv_indices,
-                self.paged_kv_last_page_len_cpu[:num_input_tokens],
+                self.paged_kv_last_page_len.cpu[:num_input_tokens],
                 seq_lens_cpu[:num_input_tokens],
                 self.num_qo_heads,
                 self.num_kv_heads,
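
For reference, the new code paths assume a CpuGpuBuffer helper exposing a pinned-host tensor (.cpu), a zero-copy NumPy view of it (.np), a device tensor (.gpu), and a copy_to_gpu() method that returns the copied device slice. The sketch below is inferred purely from the call sites in this diff; the actual implementation in vllm/v1/worker/utils.py may differ in details such as the constructor signature and multi-dimensional support.

from typing import Optional

import torch


class CpuGpuBuffer:
    """Paired pinned-host/device buffers sharing one logical array.

    Host-side writes go through `np` (a zero-copy NumPy view of the
    pinned CPU tensor) and are flushed to the device with an async
    H2D copy via `copy_to_gpu`.
    """

    def __init__(self, size: int, *, dtype: torch.dtype,
                 device: torch.device, pin_memory: bool):
        self.cpu = torch.zeros(size,
                               dtype=dtype,
                               device="cpu",
                               pin_memory=pin_memory)
        self.np = self.cpu.numpy()  # zero-copy view of the pinned tensor
        self.gpu = torch.zeros(size, dtype=dtype, device=device)

    def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor:
        # Non-blocking H2D copy of the first `n` elements (or the whole
        # buffer). Returns the device slice, matching how the diff uses
        # the return value as `paged_kv_indptr`.
        if n is None:
            return self.gpu.copy_(self.cpu, non_blocking=True)
        return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)

This single helper replaces three pairs of manually kept host/device tensors: host-side writes land in the pinned .np view (the np.cumsum and np.where calls above), the padding fills switch from torch's in-place fill_ to NumPy's fill on that view, and the CUDA-graph path keeps reading the persistent .gpu storage.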