diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index f2d0166724fc5..4710926ce921d 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -16,6 +16,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
 from vllm.v1.sample.sampler import SamplerOutput
 from vllm.v1.worker.gpu.attn_utils import (get_kv_cache_spec,
                                            init_attn_backend, init_kv_cache)
@@ -356,15 +357,21 @@ class GPUModelRunner:
         self,
         sampler_output: SamplerOutput,
         input_batch: InputBatch,
-    ) -> np.ndarray:
+    ) -> tuple[np.ndarray, np.ndarray]:
         # Get the number of sampled tokens.
         # 0 if chunked-prefilling, 1 if not.
         is_chunked_prefilling = input_batch.is_chunked_prefilling
         num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)
-        # Increment the number of tokens.
-        idx_mapping_np = input_batch.idx_mapping_np
-        self.req_states.num_tokens[idx_mapping_np] += num_sampled_tokens
-        return num_sampled_tokens
+
+        # Update the token IDs and the number of tokens.
+        sampled_token_ids_cpu = sampler_output.sampled_token_ids.cpu()
+        sampled_token_ids_np = sampled_token_ids_cpu.numpy()
+        self.req_states.append_token_ids(
+            input_batch.idx_mapping_np,
+            sampled_token_ids_np,
+            num_sampled_tokens=num_sampled_tokens,
+        )
+        return sampled_token_ids_np, num_sampled_tokens

     def execute_model(
         self,
@@ -372,17 +379,19 @@
     ):
         self.update_states(scheduler_output)
         if scheduler_output.total_num_scheduled_tokens == 0:
-            return
+            return EMPTY_MODEL_RUNNER_OUTPUT

         input_batch = self.prepare_inputs(scheduler_output)
+        num_tokens = input_batch.num_tokens_after_padding
         with set_forward_context(
             input_batch.attn_metadata,
             self.vllm_config,
+            num_tokens=num_tokens,
         ):
             hidden_states = self.model(
-                input_ids=input_batch.input_ids,
-                positions=input_batch.positions,
+                input_ids=input_batch.input_ids[:num_tokens],
+                positions=input_batch.positions[:num_tokens],
             )

         # Compute logits to sample next tokens.
@@ -393,5 +402,19 @@ class GPUModelRunner:

         prompt_logprobs = self.compute_prompt_logprobs(hidden_states,
                                                        input_batch)
-        output = self.postprocess(sampler_output, input_batch)
-        return output
+        sampled_token_ids_np, num_sampled_tokens = self.postprocess(
+            sampler_output, input_batch)
+        req_id_to_index = {
+            req_id: i
+            for i, req_id in enumerate(input_batch.req_ids)
+        }
+        return ModelRunnerOutput(
+            req_ids=input_batch.req_ids,
+            req_id_to_index=req_id_to_index,
+            sampled_token_ids=sampled_token_ids_np.tolist(),
+            logprobs=sampler_output.logprobs_tensors,
+            prompt_logprobs_dict={},
+            pooler_output=[],
+            kv_connector_output=None,
+            num_nans_in_logits=None,
+        )
diff --git a/vllm/v1/worker/gpu/states.py b/vllm/v1/worker/gpu/states.py
index cfb675315a9c4..f5e4dea82c27a 100644
--- a/vllm/v1/worker/gpu/states.py
+++ b/vllm/v1/worker/gpu/states.py
@@ -3,6 +3,8 @@
 from dataclasses import dataclass
 from typing import Optional, Union

+import numba
+from numba import types
 import numpy as np
 import torch

@@ -160,3 +162,47 @@ class RequestState:
         if self.pin_memory:
             cpu_tensor = cpu_tensor.pin_memory()
         return cpu_tensor.to(device=self.device, non_blocking=True)
+
+    def append_token_ids(
+        self,
+        req_indices: np.ndarray,
+        sampled_ids: np.ndarray,
+        num_sampled_tokens: np.ndarray,
+    ) -> None:
+        _append_token_ids(
+            req_indices,
+            sampled_ids,
+            num_sampled_tokens,
+            self.token_ids,
+            self.num_tokens,
+        )
+
+
+@numba.jit(
+    [
+        types.none(
+            types.int32[:],
+            types.int64[:, :],
+            types.int32[:],
+            types.int32[:, :],
+            types.int32[:],
+        )
+    ],
+    nopython=True,
+    cache=True,
+)
+def _append_token_ids(
+    req_indices: np.ndarray,
+    sampled_ids: np.ndarray,
+    num_sampled_tokens: np.ndarray,
+    token_ids: np.ndarray,
+    num_tokens: np.ndarray,
+) -> None:
+    num_reqs = num_sampled_tokens.shape[0]
+    for i in range(num_reqs):
+        req_idx = req_indices[i]
+        n = num_sampled_tokens[i]
+        start_idx = num_tokens[req_idx]
+        end_idx = start_idx + n
+        token_ids[req_idx, start_idx:end_idx] = sampled_ids[i, :n]
+        num_tokens[req_idx] = end_idx
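Note on the new `num_tokens_after_padding` slice in `execute_model`: `input_ids` and `positions` are persistent, preallocated buffers, and the forward pass now consumes only a padded prefix of them, so each step can run at one of a fixed set of shapes (e.g., for CUDA graph replay). A minimal sketch of the idea, with a hypothetical buffer capacity and bucket-rounding rule (the real sizes and the `InputBatch` internals are not shown in this diff):

import torch

MAX_NUM_TOKENS = 16  # hypothetical buffer capacity; the real buffers are far larger
PAD_MULTIPLE = 4     # hypothetical padding bucket

# Persistent stand-ins for InputBatch.input_ids / InputBatch.positions.
input_ids = torch.zeros(MAX_NUM_TOKENS, dtype=torch.int32)
positions = torch.zeros(MAX_NUM_TOKENS, dtype=torch.int64)

num_scheduled_tokens = 5
# Round up to the next bucket: 5 -> 8 (assumed rounding rule).
num_tokens_after_padding = (
    (num_scheduled_tokens + PAD_MULTIPLE - 1) // PAD_MULTIPLE * PAD_MULTIPLE
)

# Only the padded prefix reaches the model, never the full buffer.
model_input_ids = input_ids[:num_tokens_after_padding]
model_positions = positions[:num_tokens_after_padding]
assert model_input_ids.shape[0] == 8

Slicing instead of copying keeps the launch shapes stable across steps without reallocating the buffers.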
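For reviewers who don't want to decode the numba signature, here is a pure-NumPy restatement of what `_append_token_ids` does, on made-up toy values; the loop body mirrors the jitted one, with `num_sampled_tokens[i] == 0` for requests that are still chunked-prefilling:

import numpy as np

# Toy state: 3 request slots, room for 8 token IDs each.
token_ids = np.zeros((3, 8), dtype=np.int32)
num_tokens = np.array([4, 2, 5], dtype=np.int32)  # tokens already stored per slot

# A batch of two requests mapped to state slots 0 and 2.
req_indices = np.array([0, 2], dtype=np.int32)
sampled_ids = np.array([[101], [202]], dtype=np.int64)
# Slot 0 sampled one token; slot 2 is mid-chunked-prefill, so nothing is appended.
num_sampled_tokens = np.array([1, 0], dtype=np.int32)

for i in range(num_sampled_tokens.shape[0]):
    req_idx = req_indices[i]
    n = num_sampled_tokens[i]
    start = num_tokens[req_idx]
    # Write the sampled IDs in place and advance the per-request length.
    token_ids[req_idx, start:start + n] = sampled_ids[i, :n]
    num_tokens[req_idx] = start + n

assert token_ids[0, 4] == 101            # appended after the 4 existing tokens
assert num_tokens.tolist() == [5, 2, 5]  # only slot 0 advanced

Because the decorator lists an explicit signature, numba compiles the kernel eagerly at decoration time rather than on the first call, and `cache=True` persists the compiled code on disk across processes.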