diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py
index 4510a1c5ca1e9..5aa1a33d851cc 100644
--- a/vllm/v1/worker/gpu/attn_utils.py
+++ b/vllm/v1/worker/gpu/attn_utils.py
@@ -18,7 +18,6 @@ from vllm.v1.kv_cache_interface import (
     KVCacheConfig,
     KVCacheSpec,
 )
-from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.utils import bind_kv_cache
 
 
@@ -145,7 +144,8 @@ def build_attn_metadata(
     attn_metadata_builders: list[AttentionMetadataBuilder],
     num_reqs: int,
     num_tokens: int,
-    query_start_loc: CpuGpuBuffer,
+    query_start_loc_gpu: torch.Tensor,
+    query_start_loc_cpu: torch.Tensor,
     seq_lens: torch.Tensor,
     seq_lens_np: np.ndarray,
     num_computed_tokens_cpu: torch.Tensor | None,
@@ -153,9 +153,7 @@ def build_attn_metadata(
     slot_mappings: torch.Tensor,
     kv_cache_config: KVCacheConfig,
 ) -> dict[str, Any]:
-    query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
-    query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1]
-    max_query_len = int(query_start_loc.np[: num_reqs + 1].max())
+    max_query_len = int(query_start_loc_cpu.max())
     seq_lens = seq_lens[:num_reqs]
     seq_lens_cpu = torch.from_numpy(seq_lens_np)
     max_seq_len = int(seq_lens_np.max())
diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py
index 6b056641c903d..b5fc2edea130f 100644
--- a/vllm/v1/worker/gpu/cudagraph_utils.py
+++ b/vllm/v1/worker/gpu/cudagraph_utils.py
@@ -120,7 +120,8 @@ class CudaGraphManager:
             attn_metadata_builders=attn_metadata_builders,
             num_reqs=batch_size,
             num_tokens=batch_size,
-            query_start_loc=input_buffers.query_start_loc,
+            query_start_loc_gpu=input_buffers.query_start_loc.gpu[: batch_size + 1],
+            query_start_loc_cpu=input_buffers.query_start_loc.cpu[: batch_size + 1],
             seq_lens=input_buffers.seq_lens,
             seq_lens_np=np.full(batch_size, self.max_model_len, dtype=np.int32),
             num_computed_tokens_cpu=None,  # FIXME
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 6a78776b0a8a3..ed41e5a1a6c5e 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -226,11 +226,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         num_computed_tokens = torch.zeros(
             input_batch.num_reqs, dtype=torch.int32, device=self.device
         )
+        query_start_loc = self.input_buffers.query_start_loc
+        query_start_loc_gpu = query_start_loc.gpu[: input_batch.num_reqs + 1]
+        query_start_loc_cpu = query_start_loc.cpu[: input_batch.num_reqs + 1]
         attn_metadata = build_attn_metadata(
             attn_metadata_builders=self.attn_metadata_builders,
             num_reqs=input_batch.num_reqs,
             num_tokens=input_batch.num_tokens,
-            query_start_loc=self.input_buffers.query_start_loc,
+            query_start_loc_gpu=query_start_loc_gpu,
+            query_start_loc_cpu=query_start_loc_cpu,
             seq_lens=self.input_buffers.seq_lens,
             seq_lens_np=input_batch.seq_lens_np,
             num_computed_tokens_cpu=num_computed_tokens,
@@ -515,6 +519,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens
         self.input_buffers.query_start_loc.copy_to_gpu()
         query_start_loc_gpu = self.input_buffers.query_start_loc.gpu[: num_reqs + 1]
+        query_start_loc_cpu = self.input_buffers.query_start_loc.cpu[: num_reqs + 1]
         query_start_loc_np = self.input_buffers.query_start_loc.np[: num_reqs + 1]
 
         # Copy prefill tokens from CPU to GPU.
@@ -572,7 +577,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             attn_metadata_builders=self.attn_metadata_builders,
             num_reqs=num_reqs,
             num_tokens=num_tokens,
-            query_start_loc=self.input_buffers.query_start_loc,
+            query_start_loc_gpu=query_start_loc_gpu,
+            query_start_loc_cpu=query_start_loc_cpu,
             seq_lens=self.input_buffers.seq_lens,
             seq_lens_np=seq_lens_np,
             num_computed_tokens_cpu=num_computed_tokens,
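
For context on the pattern being refactored: `build_attn_metadata` previously took the whole `CpuGpuBuffer` and sliced out the GPU/CPU views itself; after this change, each caller slices once and passes plain `torch.Tensor` views, which is what lets `attn_utils.py` drop its `CpuGpuBuffer` import entirely. Below is a minimal sketch of the idea, using a simplified stand-in for vLLM's `CpuGpuBuffer` (only the `.cpu`/`.gpu`/`.np` views and `copy_to_gpu()` visible in the diff are modeled; the real class in `vllm.v1.utils` may differ, and the constructor signature here is assumed):

```python
import numpy as np
import torch


class CpuGpuBuffer:
    """Simplified stand-in: a pinned host tensor, a zero-copy NumPy
    view of it, and a same-shape device tensor synced via copy_to_gpu()."""

    def __init__(self, size: int, dtype: torch.dtype, device: torch.device) -> None:
        self.cpu = torch.zeros(size, dtype=dtype, pin_memory=True)
        self.np = self.cpu.numpy()  # shares memory with the pinned host tensor
        self.gpu = torch.zeros(size, dtype=dtype, device=device)

    def copy_to_gpu(self) -> torch.Tensor:
        # non_blocking is safe because self.cpu is pinned.
        return self.gpu.copy_(self.cpu, non_blocking=True)


# Old convention: pass the buffer and let build_attn_metadata slice it.
# New convention (this diff): slice at the call site, pass plain tensors.
num_reqs, max_reqs = 4, 8
buf = CpuGpuBuffer(max_reqs + 1, dtype=torch.int32, device=torch.device("cuda"))
buf.np[: num_reqs + 1] = np.arange(num_reqs + 1, dtype=np.int32)
buf.copy_to_gpu()

query_start_loc_gpu = buf.gpu[: num_reqs + 1]  # device-side view
query_start_loc_cpu = buf.cpu[: num_reqs + 1]  # pinned host view
max_query_len = int(query_start_loc_cpu.max())  # as computed in the new code
```

Keeping `build_attn_metadata` agnostic of the buffer wrapper is also what allows `CudaGraphManager` to feed it fixed-size slices of its persistent capture buffers directly, without the builder needing any knowledge of how those buffers are managed.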