From 4be2c66e37f4a85f51b1d6a88d41699406391958 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 19 Sep 2025 09:35:38 +0000 Subject: [PATCH] fix Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu/input_batch.py | 50 +++++++++++++++- vllm/v1/worker/gpu/model_runner.py | 95 ++++++++++++++++++------------ vllm/v1/worker/gpu/states.py | 80 +++++++------------------ 3 files changed, 127 insertions(+), 98 deletions(-) diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py index e1000e52c5f69..95c9ecee6ffc8 100644 --- a/vllm/v1/worker/gpu/input_batch.py +++ b/vllm/v1/worker/gpu/input_batch.py @@ -5,9 +5,11 @@ from dataclasses import dataclass from typing import Any import numba +import numba.types as types import numpy as np import torch -from numba import types +import triton +import triton.language as tl from vllm.v1.utils import CpuGpuBuffer @@ -161,3 +163,49 @@ def prepare_inputs( query_start_loc[num_reqs + 1:].fill(cu_num_tokens) # Fill unused with 0 for full cuda graph mode. seq_lens[num_reqs:].fill(0) + + +@triton.jit +def _combine_last_token_ids_kernel( + input_ids_ptr, + idx_mapping_ptr, + last_token_ids_ptr, + query_start_loc_ptr, + seq_lens_ptr, + num_tokens_ptr, +): + batch_idx = tl.program_id(0) + req_state_idx = tl.load(idx_mapping_ptr + batch_idx) + + seq_len = tl.load(seq_lens_ptr + batch_idx) + num_tokens = tl.load(num_tokens_ptr + req_state_idx) + if seq_len < num_tokens: + # Chunked prefilling. 
+ return + + last_token_id = tl.load(last_token_ids_ptr + req_state_idx) + if last_token_id == -1: + return + + end = tl.load(query_start_loc_ptr + batch_idx + 1) + tl.store(input_ids_ptr + end - 1, last_token_id) + + +def combine_last_token_ids( + input_ids: torch.Tensor, + idx_mapping: torch.Tensor, + last_token_ids: torch.Tensor, + query_start_loc: torch.Tensor, + seq_lens: torch.Tensor, + num_tokens: torch.Tensor, +) -> torch.Tensor: + num_reqs = seq_lens.shape[0] + _combine_last_token_ids_kernel[(num_reqs, )]( + input_ids, + idx_mapping, + last_token_ids, + query_start_loc, + seq_lens, + num_tokens, + ) + return input_ids diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 4a7ed7a6af40b..4f22c70e732f7 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -27,6 +27,7 @@ from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.dist_utils import (all_gather_sampler_output, evenly_split) from vllm.v1.worker.gpu.input_batch import (InputBatch, InputBuffers, + combine_last_token_ids, prepare_inputs) from vllm.v1.worker.gpu.sampler import Sampler from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata @@ -158,8 +159,8 @@ class GPUModelRunner: num_tokens=num_tokens, ): hidden_states = self.model( - input_ids=input_batch.input_ids[:num_tokens], - positions=input_batch.positions[:num_tokens], + input_ids=input_batch.input_ids, + positions=input_batch.positions, ) sample_hidden_states = hidden_states[input_batch.logits_indices] return hidden_states, sample_hidden_states @@ -205,7 +206,7 @@ class GPUModelRunner: [] for _ in range(self.block_tables.num_kv_cache_groups)) overwrite: list[bool] = [] - # Add new requests to the cached states. + # Add new requests. 
for new_req_data in scheduler_output.scheduled_new_reqs: req_id = new_req_data.req_id self.req_states.add_request( @@ -223,7 +224,7 @@ class GPUModelRunner: new_block_ids[i].extend(block_ids) overwrite.append(True) - # Update the states of the running/resumed requests. + # Add new blocks for the existing requests. cached_reqs = scheduler_output.scheduled_cached_reqs for i, req_id in enumerate(cached_reqs.req_ids): req_index = self.req_states.req_id_to_index[req_id] @@ -237,9 +238,6 @@ class GPUModelRunner: new_block_ids[group_id].extend(block_ids) overwrite.append(False) - self.req_states.num_computed_tokens[req_index] = ( - cached_reqs.num_computed_tokens[i]) - if req_indices: self.block_tables.append_block_ids( req_indices=req_indices, @@ -275,54 +273,61 @@ class GPUModelRunner: # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks] block_tables = self.block_tables.gather_block_tables(idx_mapping) - input_ids = self.input_buffers.input_ids - positions = self.input_buffers.positions - query_start_loc = self.input_buffers.query_start_loc - seq_lens = self.input_buffers.seq_lens prepare_inputs( idx_mapping_np, - self.req_states.token_ids, + self.req_states.prompt_token_ids, self.req_states.num_computed_tokens, num_scheduled_tokens, - input_ids.np, - positions.np, - query_start_loc.np, - seq_lens.np, + self.input_buffers.input_ids.np, + self.input_buffers.positions.np, + self.input_buffers.query_start_loc.np, + self.input_buffers.seq_lens.np, ) - input_ids.copy_to_gpu(num_tokens) - positions.copy_to_gpu(num_tokens) - + self.input_buffers.input_ids.copy_to_gpu(num_tokens) + self.input_buffers.positions.copy_to_gpu(num_tokens) # NOTE(woosuk): We should copy the whole query_start_loc and seq_lens # tensors from CPU to GPU, because they may include paddings needed # for full CUDA graph mode. 
-        query_start_loc.copy_to_gpu()
-        query_start_loc_cpu = query_start_loc.cpu[:num_reqs + 1]
+        self.input_buffers.query_start_loc.copy_to_gpu()
+        self.input_buffers.seq_lens.copy_to_gpu()
+        query_start_loc = self.input_buffers.query_start_loc
         query_start_loc_gpu = query_start_loc.gpu[:num_reqs + 1]
+        query_start_loc_cpu = query_start_loc.cpu[:num_reqs + 1]
         max_query_len = int(num_scheduled_tokens.max())
-
-        seq_lens.copy_to_gpu()
-        seq_lens_cpu = seq_lens.cpu[:num_reqs]
-        seq_lens_np = seq_lens.np[:num_reqs]
+        seq_lens_gpu = self.input_buffers.seq_lens.gpu[:num_reqs]
+        seq_lens_cpu = self.input_buffers.seq_lens.cpu[:num_reqs]
+        seq_lens_np = self.input_buffers.seq_lens.np[:num_reqs]
         max_seq_len = int(seq_lens_np.max())
-        seq_lens_gpu = seq_lens.gpu[:num_reqs]
-
-        num_computed_tokens_np = self.req_states.num_computed_tokens[
-            idx_mapping_np]
-        num_computed_tokens_cpu = torch.from_numpy(num_computed_tokens_np)
-        is_chunked_prefilling = (seq_lens_np
-                                 < self.req_states.num_tokens[idx_mapping_np])
+        # Some input token ids are directly read from the last sampled tokens.
+        combine_last_token_ids(
+            self.input_buffers.input_ids.gpu,
+            idx_mapping,
+            self.req_states.last_sampled_tokens,
+            query_start_loc_gpu,
+            seq_lens_gpu,
+            self.req_states.num_tokens.copy_to_gpu(),
+        )
 
-        # Slot mappings: [num_kv_cache_groups, num_tokens]
+        # Compute slot mappings: [num_kv_cache_groups, num_tokens]
         slot_mappings = self.block_tables.compute_slot_mappings(
-            query_start_loc_gpu, positions.gpu[:num_tokens])
+            query_start_loc_gpu, self.input_buffers.positions.gpu[:num_tokens])
+
+        num_computed_tokens_cpu = torch.from_numpy(
+            self.req_states.num_computed_tokens[idx_mapping_np])
+
+        # Whether the request is chunked-prefilling or not.
+        is_chunked_prefilling = (
+            seq_lens_np < self.req_states.num_tokens.np[idx_mapping_np])
+
+        # Logits indices to sample next token from.
         logits_indices = query_start_loc_gpu[1:] - 1
         num_logits_indices = logits_indices.size(0)
 
         # Layer name -> attention metadata.
attn_metadata: dict[str, Any] = {} - for i, kv_cache_spec in enumerate( - self.kv_cache_config.kv_cache_groups): + kv_cache_groups = self.kv_cache_config.kv_cache_groups + for i, kv_cache_spec in enumerate(kv_cache_groups): block_table = block_tables[i] slot_mapping = slot_mappings[i] @@ -352,6 +357,8 @@ class GPUModelRunner: for layer_name in kv_cache_spec.layer_names: attn_metadata[layer_name] = metadata + input_ids = self.input_buffers.input_ids.gpu[:num_tokens_after_padding] + positions = self.input_buffers.positions.gpu[:num_tokens_after_padding] return InputBatch( req_ids=req_ids, num_reqs=num_reqs, @@ -361,8 +368,8 @@ class GPUModelRunner: num_tokens=num_tokens, num_tokens_after_padding=num_tokens_after_padding, is_chunked_prefilling=is_chunked_prefilling, - input_ids=input_ids.gpu, - positions=positions.gpu, + input_ids=input_ids, + positions=positions, attn_metadata=attn_metadata, logits_indices=logits_indices, ) @@ -412,10 +419,20 @@ class GPUModelRunner: sampler_output: SamplerOutput, input_batch: InputBatch, ) -> AsyncOutput: + # Store the last sampled token ids. + self.req_states.last_sampled_tokens[input_batch.idx_mapping] = ( + sampler_output.sampled_token_ids) + # Get the number of sampled tokens. # 0 if chunked-prefilling, 1 if not. is_chunked_prefilling = input_batch.is_chunked_prefilling num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32) + # Increment the number of tokens. + idx_mapping_np = input_batch.idx_mapping_np + self.req_states.num_tokens.np[idx_mapping_np] += num_sampled_tokens + # Increment the number of computed tokens. 
+ self.req_states.num_computed_tokens[idx_mapping_np] += ( + input_batch.num_scheduled_tokens) model_runner_output = ModelRunnerOutput( req_ids=input_batch.req_ids, @@ -450,8 +467,8 @@ class GPUModelRunner: num_tokens=num_tokens, ): hidden_states = self.model( - input_ids=input_batch.input_ids[:num_tokens], - positions=input_batch.positions[:num_tokens], + input_ids=input_batch.input_ids, + positions=input_batch.positions, ) sampler_output = self.sample(hidden_states, input_batch) diff --git a/vllm/v1/worker/gpu/states.py b/vllm/v1/worker/gpu/states.py index 4deabd2439097..721cceadafbbe 100644 --- a/vllm/v1/worker/gpu/states.py +++ b/vllm/v1/worker/gpu/states.py @@ -3,8 +3,6 @@ from dataclasses import dataclass from typing import Optional -import numba -import numba.types as types import numpy as np import torch @@ -76,21 +74,22 @@ class RequestState: self.index_to_req_id: dict[int, str] = {} self.free_indices = list(range(max_num_reqs)) - # TODO(woosuk): Because the token_ids tensor can be very big, we only - # initialize it on CPU memory. - self.token_ids = np.zeros( + # NOTE(woosuk): Strictly speaking, it contains prompt + some output + # because of preemption. + self.prompt_token_ids = np.zeros( (self.max_num_reqs, self.max_model_len), dtype=np.int32, ) - self.num_tokens = np.zeros(self.max_num_reqs, dtype=np.int32) + self.num_tokens = self._make_buffer(self.max_num_reqs, + dtype=torch.int32) self.num_computed_tokens = np.zeros(self.max_num_reqs, dtype=np.int32) - self.num_prompt_tokens = np.zeros(self.max_num_reqs, dtype=np.int32) - # Last sampled token ids. - self.last_token = torch.zeros( + # Last sampled tokens. + self.last_sampled_tokens = torch.zeros( self.max_num_reqs, - dtype=torch.int32, - device=self.device, + 1, + dtype=torch.int64, + device=device, ) # Sampling parameters. 
@@ -110,6 +109,12 @@ class RequestState: device=self.device, pin_memory=self.pin_memory) + def _make_buffer(self, size: int, dtype: torch.dtype) -> CpuGpuBuffer: + return CpuGpuBuffer(size, + dtype=dtype, + device=self.device, + pin_memory=self.pin_memory) + @property def num_reqs(self) -> int: return len(self.req_id_to_index) @@ -126,11 +131,14 @@ class RequestState: self.req_id_to_index[req_id] = req_idx self.index_to_req_id[req_idx] = req_id + # NOTE(woosuk): Strictly speaking, "prompt_len" here may include + # output tokens, if the request is resumed from preemption. prompt_len = len(prompt_token_ids) - self.num_tokens[req_idx] = prompt_len - self.num_prompt_tokens[req_idx] = prompt_len - self.token_ids[req_idx, :prompt_len] = prompt_token_ids + self.prompt_token_ids[req_idx, :prompt_len] = prompt_token_ids + self.num_tokens.np[req_idx] = prompt_len self.num_computed_tokens[req_idx] = num_computed_tokens + # TODO(woosuk): Optimize. + self.last_sampled_tokens[req_idx].fill_(-1) self.temperature.np[req_idx] = sampling_params.temperature self.top_p.np[req_idx] = sampling_params.top_p @@ -197,50 +205,6 @@ class RequestState: max_num_logprobs=max_num_logprobs, ) - def append_token_ids( - self, - req_indices: np.ndarray, - sampled_ids: np.ndarray, - num_sampled_tokens: np.ndarray, - ) -> None: - _append_token_ids( - req_indices, - sampled_ids, - num_sampled_tokens, - self.token_ids, - self.num_tokens, - ) - - -@numba.jit( - [ - types.none( - types.int32[:], - types.int64[:, :], - types.int32[:], - types.int32[:, :], - types.int32[:], - ) - ], - nopython=True, - cache=True, -) -def _append_token_ids( - req_indices: np.ndarray, - sampled_ids: np.ndarray, - num_sampled_tokens: np.ndarray, - token_ids: np.ndarray, - num_tokens: np.ndarray, -) -> None: - num_reqs = num_sampled_tokens.shape[0] - for i in range(num_reqs): - req_idx = req_indices[i] - n = num_sampled_tokens[i] - start_idx = num_tokens[req_idx] - end_idx = start_idx + n - token_ids[req_idx, start_idx:end_idx] 
= sampled_ids[i, :n] - num_tokens[req_idx] = end_idx - class Param: