diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py
index 1177d25e300c..3ac43ea4952d 100644
--- a/vllm/v1/worker/gpu/input_batch.py
+++ b/vllm/v1/worker/gpu/input_batch.py
@@ -4,7 +4,6 @@ from dataclasses import dataclass
 from typing import Any

 import numba
-import numba.types as types
 import numpy as np
 import torch

@@ -147,80 +146,42 @@ class InputBatch:
     )


-# NOTE: With the type annotations, this function is pre-compiled
-# before the first call.
-@numba.jit(
-    [
-        types.none(
-            types.int32[:],  # idx_mapping
-            types.int32[:],  # num_scheduled_tokens
-            types.int32[:, :],  # prefill_token_ids
-            types.int32[:],  # num_computed_prefill_tokens
-            types.int32[:],  # prefill_len
-            types.int32[:],  # input_ids
-            types.int32[:],  # query_start_loc
-        )
-    ],
-    nopython=True,
-    cache=True,
-)
+@numba.njit(cache=True)
 def _prepare_prefill_inputs(
-    idx_mapping: np.ndarray,  # batch_idx -> req_idx
-    num_scheduled_tokens: np.ndarray,  # [B]
+    idx_mapping: np.ndarray,  # [B]
+    query_lens: np.ndarray,  # [B]
+    query_start_loc: np.ndarray,  # [B + 1]
     prefill_token_ids: np.ndarray,  # [N, max_model_len]
     num_computed_prefill_tokens: np.ndarray,  # [N]
-    prefill_len: np.ndarray,  # [N]
     input_ids: np.ndarray,  # [num_input_tokens]
-    query_start_loc: np.ndarray,  # [B + 1]
 ) -> None:
-    num_reqs = num_scheduled_tokens.shape[0]
-    query_start_loc[0] = 0
-
-    cu_num_tokens = 0
+    num_reqs = idx_mapping.shape[0]
+    query_starts = query_start_loc[:num_reqs]
+    query_ends = query_start_loc[1 : num_reqs + 1]
+    starts = num_computed_prefill_tokens[idx_mapping]
+    ends = starts + query_lens
     for i in range(num_reqs):
-        req_idx = idx_mapping[i]
-        query_len = num_scheduled_tokens[i]
-
-        start = num_computed_prefill_tokens[req_idx]
-        end = min(start + query_len, prefill_len[req_idx])
-        n = end - start
-
-        start_idx = cu_num_tokens
-        input_ids[start_idx : start_idx + n] = prefill_token_ids[req_idx, start:end]
-
-        cu_num_tokens = start_idx + query_len
-        query_start_loc[i + 1] = cu_num_tokens
-
-    # Pad the inputs for CUDA graphs.
-    # Note: pad query_start_loc to be non-decreasing, as kernels
-    # like FlashAttention requires that
-    query_start_loc[num_reqs + 1 :].fill(cu_num_tokens)
+        input_ids[query_starts[i] : query_ends[i]] = prefill_token_ids[
+            idx_mapping[i], starts[i] : ends[i]
+        ]


 def prepare_prefill_inputs(
     idx_mapping: np.ndarray,
     num_scheduled_tokens: np.ndarray,
-    total_num_tokens: int,
+    query_start_loc: np.ndarray,
     prefill_token_ids: np.ndarray,
     num_computed_prefill_tokens: np.ndarray,
-    prefill_len: np.ndarray,
-    input_ids: CpuGpuBuffer,
-    query_start_loc: CpuGpuBuffer,
+    input_ids: np.ndarray,
 ) -> None:
     _prepare_prefill_inputs(
         idx_mapping,
         num_scheduled_tokens,
+        query_start_loc,
         prefill_token_ids,
         num_computed_prefill_tokens,
-        prefill_len,
-        input_ids.np,
-        query_start_loc.np,
+        input_ids,
     )
-    input_ids.copy_to_gpu(total_num_tokens)
-    # NOTE(woosuk): We should copy the whole query_start_loc and seq_lens
-    # tensors from CPU to GPU, because they may include paddings needed
-    # for full CUDA graph mode.
-    query_start_loc.copy_to_gpu()


 @triton.jit
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 205298a415d4..e0ed183d3c5b 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -502,20 +502,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
         block_tables = self.block_tables.gather_block_tables(idx_mapping)

-        # Copy prefill tokens from CPU to GPU and get query_start_loc.
+        # Get query_start_loc.
+        np.cumsum(
+            num_scheduled_tokens,
+            out=self.input_buffers.query_start_loc.np[1 : num_reqs + 1],
+        )
+        # Pad for full CUDA graph mode.
+        # Some attention backends like FA3 require query_start_loc to be non-decreasing.
+        self.input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens
+        self.input_buffers.query_start_loc.copy_to_gpu()
+        query_start_loc_gpu = self.input_buffers.query_start_loc.gpu[: num_reqs + 1]
+        query_start_loc_np = self.input_buffers.query_start_loc.np[: num_reqs + 1]
+
+        # Copy prefill tokens from CPU to GPU.
         prepare_prefill_inputs(
             idx_mapping_np,
             num_scheduled_tokens,
-            num_tokens,
+            query_start_loc_np,
             self.req_states.prefill_token_ids,
             self.req_states.num_computed_prefill_tokens,
-            self.req_states.prefill_len.np,
-            self.input_buffers.input_ids,
-            self.input_buffers.query_start_loc,
+            self.input_buffers.input_ids.np,
         )
-        query_start_loc = self.input_buffers.query_start_loc
-        query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
-        query_start_loc_np = query_start_loc.np[: num_reqs + 1]
+        self.input_buffers.input_ids.copy_to_gpu(num_tokens)

         # Prepare positions and seq_lens.
         prepare_pos_seq_lens(
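
For reference, here is a minimal self-contained sketch of the new flow, with toy data and plain numpy arrays standing in for vLLM's CpuGpuBuffer plumbing (the toy names and sizes below are illustrative, not from the PR): query_start_loc is built once on the CPU with np.cumsum and padded so it stays non-decreasing, and the numba kernel is reduced to gathering each request's token slice into input_ids.

import numba
import numpy as np


# Simplified form of the njit kernel from the diff: query_start_loc is
# precomputed, so the kernel only copies each request's token slice.
@numba.njit(cache=True)
def _prepare_prefill_inputs(idx_mapping, query_lens, query_start_loc,
                            prefill_token_ids, num_computed_prefill_tokens,
                            input_ids):
    num_reqs = idx_mapping.shape[0]
    starts = num_computed_prefill_tokens[idx_mapping]
    ends = starts + query_lens
    for i in range(num_reqs):
        input_ids[query_start_loc[i] : query_start_loc[i + 1]] = (
            prefill_token_ids[idx_mapping[i], starts[i] : ends[i]])


# Toy batch: requests 0 and 2 of a 3-request token table (values made up).
idx_mapping = np.array([0, 2], dtype=np.int32)
num_scheduled_tokens = np.array([3, 2], dtype=np.int32)
prefill_token_ids = np.arange(30, dtype=np.int32).reshape(3, 10)
num_computed_prefill_tokens = np.array([1, 0, 4], dtype=np.int32)

num_reqs = idx_mapping.shape[0]
num_tokens = int(num_scheduled_tokens.sum())
max_num_reqs = 4  # padded batch size, standing in for the CUDA graph padding

# Build query_start_loc with cumsum, then pad the tail so it stays
# non-decreasing, as the diff does for full CUDA graph mode.
query_start_loc = np.zeros(max_num_reqs + 1, dtype=np.int32)
np.cumsum(num_scheduled_tokens, out=query_start_loc[1 : num_reqs + 1])
query_start_loc[num_reqs + 1 :] = num_tokens

input_ids = np.zeros(16, dtype=np.int32)
_prepare_prefill_inputs(idx_mapping, num_scheduled_tokens,
                        query_start_loc[: num_reqs + 1],
                        prefill_token_ids, num_computed_prefill_tokens,
                        input_ids)
print(input_ids[:num_tokens])  # -> [ 1  2  3 24 25]

Note the division of labor this refactor settles on: the cheap, vectorizable cumsum and padding stay in plain numpy in the model runner, while the per-request gather, which is the only genuinely loop-shaped work, remains in the numba kernel.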