From a8e7071924f8c1e61465eb3277299e3940872357 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Fri, 19 Sep 2025 08:33:47 -0700
Subject: [PATCH] minor

Signed-off-by: Woosuk Kwon
---
 vllm/v1/worker/gpu/input_batch.py  | 37 +++++++++++++++++++++++++++++-
 vllm/v1/worker/gpu/model_runner.py | 21 ++++++-----------
 2 files changed, 43 insertions(+), 15 deletions(-)

diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py
index 95c9ecee6ffc8..5cca3f1442714 100644
--- a/vllm/v1/worker/gpu/input_batch.py
+++ b/vllm/v1/worker/gpu/input_batch.py
@@ -128,7 +128,7 @@ class InputBatch:
     nopython=True,
     cache=True,
 )
-def prepare_inputs(
+def _prepare_inputs(
     idx_mapping: np.ndarray,  # batch_idx -> req_idx
     token_ids: np.ndarray,  # [N, max_model_len]
     num_computed_tokens: np.ndarray,  # [N]
@@ -165,6 +165,41 @@ def prepare_inputs(
     seq_lens[num_reqs:].fill(0)
 
 
+def prepare_inputs(
+    idx_mapping: np.ndarray,
+    prompt_token_ids: np.ndarray,
+    num_computed_tokens: np.ndarray,
+    num_scheduled_tokens: np.ndarray,
+    input_ids: CpuGpuBuffer,
+    positions: CpuGpuBuffer,
+    query_start_loc: CpuGpuBuffer,
+    seq_lens: CpuGpuBuffer,
+    num_tokens: int,
+) -> tuple[int, int]:
+    _prepare_inputs(
+        idx_mapping,
+        prompt_token_ids,
+        num_computed_tokens,
+        num_scheduled_tokens,
+        input_ids.np,
+        positions.np,
+        query_start_loc.np,
+        seq_lens.np,
+    )
+    input_ids.copy_to_gpu(num_tokens)
+    positions.copy_to_gpu(num_tokens)
+    # NOTE(woosuk): We should copy the whole query_start_loc and seq_lens
+    # tensors from CPU to GPU, because they may include paddings needed
+    # for full CUDA graph mode.
+    query_start_loc.copy_to_gpu()
+    seq_lens.copy_to_gpu()
+
+    num_reqs = num_scheduled_tokens.shape[0]
+    max_query_len = int(num_scheduled_tokens.max())
+    max_seq_len = int(seq_lens.np[:num_reqs].max())
+    return max_query_len, max_seq_len
+
+
 @triton.jit
 def _combine_last_token_ids_kernel(
     input_ids_ptr,
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 4f22c70e732f7..446de93cc430e 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -273,31 +273,24 @@ class GPUModelRunner:
         # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
         block_tables = self.block_tables.gather_block_tables(idx_mapping)
 
-        prepare_inputs(
+        max_query_len, max_seq_len = prepare_inputs(
             idx_mapping_np,
             self.req_states.prompt_token_ids,
             self.req_states.num_computed_tokens,
             num_scheduled_tokens,
-            self.input_buffers.input_ids.np,
-            self.input_buffers.positions.np,
-            self.input_buffers.query_start_loc.np,
-            self.input_buffers.seq_lens.np,
+            self.input_buffers.input_ids,
+            self.input_buffers.positions,
+            self.input_buffers.query_start_loc,
+            self.input_buffers.seq_lens,
+            num_tokens,
         )
-        self.input_buffers.input_ids.copy_to_gpu(num_tokens)
-        self.input_buffers.positions.copy_to_gpu(num_tokens)
-        # NOTE(woosuk): We should copy the whole query_start_loc and seq_lens
-        # tensors from CPU to GPU, because they may include paddings needed
-        # for full CUDA graph mode.
-        self.input_buffers.query_start_loc.copy_to_gpu()
-        self.input_buffers.seq_lens.copy_to_gpu()
+
         query_start_loc = self.input_buffers.query_start_loc
         query_start_loc_gpu = query_start_loc.gpu[:num_reqs + 1]
         query_start_loc_cpu = query_start_loc.cpu[:num_reqs + 1]
-        max_query_len = int(num_scheduled_tokens.max())
 
         seq_lens_gpu = self.input_buffers.seq_lens.gpu[:num_reqs]
         seq_lens_cpu = self.input_buffers.seq_lens.np[:num_reqs]
         seq_lens_np = self.input_buffers.seq_lens.np[:num_reqs]
-        max_seq_len = int(seq_lens_np.max())
 
         # Some input token ids are directly read from the last sampled tokens.
         combine_last_token_ids(
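
Note for reviewers: the numba-jitted _prepare_inputs body is not shown in this
hunk. As a reading aid, here is a minimal pure-numpy sketch of the packing it
performs, inferred from the signature and the seq_lens[num_reqs:].fill(0)
context line above; the helper name and loop body are illustrative
assumptions, not the actual kernel.

    import numpy as np

    # Illustrative only: the real kernel is numba-jitted and writes into
    # preallocated buffers of fixed capacity.
    def _prepare_inputs_sketch(
        idx_mapping: np.ndarray,           # batch_idx -> req_idx
        token_ids: np.ndarray,             # [N, max_model_len]
        num_computed_tokens: np.ndarray,   # [N]
        num_scheduled_tokens: np.ndarray,  # [num_reqs]
        input_ids: np.ndarray,
        positions: np.ndarray,
        query_start_loc: np.ndarray,
        seq_lens: np.ndarray,
    ) -> None:
        num_reqs = num_scheduled_tokens.shape[0]
        # Gather per-request state in batch order.
        computed = num_computed_tokens[idx_mapping]
        # query_start_loc is the exclusive prefix sum of the scheduled counts.
        query_start_loc[0] = 0
        np.cumsum(num_scheduled_tokens, out=query_start_loc[1:num_reqs + 1])
        # Pad the tail with the total so the array stays monotonically
        # non-decreasing (assumed; this is the padding the NOTE(woosuk)
        # comment says full CUDA graph mode needs).
        query_start_loc[num_reqs + 1:] = query_start_loc[num_reqs]
        # Sequence length = already-computed prefix + newly scheduled tokens;
        # the tail zeroing mirrors seq_lens[num_reqs:].fill(0) above.
        seq_lens[:num_reqs] = computed + num_scheduled_tokens
        seq_lens[num_reqs:] = 0
        # Flatten the scheduled token ids and their positions into the 1D
        # input buffers, request by request.
        for i in range(num_reqs):
            start, end = query_start_loc[i], query_start_loc[i + 1]
            positions[start:end] = np.arange(computed[i], seq_lens[i])
            req = idx_mapping[i]
            input_ids[start:end] = token_ids[req, computed[i]:seq_lens[i]]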
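
The new wrapper assumes CpuGpuBuffer pairs a pinned-CPU staging area with a
device tensor and exposes exactly the members used above: .np (numpy view fed
to the numba kernel), .cpu / .gpu (torch tensors), and copy_to_gpu(), which
copies a prefix when given a length and the whole buffer, padding included,
when called with no argument. A minimal sketch of that contract, assuming
torch and pinned host memory (the class in the vLLM tree may differ in
detail):

    import torch

    class CpuGpuBuffer:
        def __init__(self, *size: int, dtype: torch.dtype,
                     device: torch.device | str = "cuda") -> None:
            # Pinned host memory enables non-blocking H2D copies.
            self.cpu = torch.zeros(*size, dtype=dtype, pin_memory=True)
            # Zero-copy numpy view of the pinned tensor, for the numba kernel.
            self.np = self.cpu.numpy()
            self.gpu = torch.zeros(*size, dtype=dtype, device=device)

        def copy_to_gpu(self, n: int | None = None) -> torch.Tensor:
            if n is None:
                # Whole-buffer copy: keeps the tail padding that full CUDA
                # graph mode expects in query_start_loc and seq_lens.
                self.gpu.copy_(self.cpu, non_blocking=True)
                return self.gpu
            self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
            return self.gpu[:n]

Under this contract, prepare_inputs keeps the kernel purely numpy-facing and
centralizes the H2D copies that model_runner.py previously issued by hand,
which is what lets the call site shrink in the second hunk.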