Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Woosuk Kwon, 2025-09-18 13:10:35 -07:00
parent 92f337faeb
commit cbdb47dc01
2 changed files with 79 additions and 10 deletions

View File

@@ -16,6 +16,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
 from vllm.v1.sample.sampler import SamplerOutput
 from vllm.v1.worker.gpu.attn_utils import (get_kv_cache_spec,
                                            init_attn_backend, init_kv_cache)
@@ -356,15 +357,21 @@ class GPUModelRunner:
         self,
         sampler_output: SamplerOutput,
         input_batch: InputBatch,
-    ) -> np.ndarray:
+    ) -> tuple[np.ndarray, np.ndarray]:
         # Get the number of sampled tokens.
         # 0 if chunked-prefilling, 1 if not.
         is_chunked_prefilling = input_batch.is_chunked_prefilling
         num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)
-        # Increment the number of tokens.
-        idx_mapping_np = input_batch.idx_mapping_np
-        self.req_states.num_tokens[idx_mapping_np] += num_sampled_tokens
-        return num_sampled_tokens
+        # Update the token IDs and the number of tokens.
+        sampled_token_ids_cpu = sampler_output.sampled_token_ids.cpu()
+        sampled_token_ids_np = sampled_token_ids_cpu.numpy()
+        self.req_states.append_token_ids(
+            input_batch.idx_mapping_np,
+            sampled_token_ids_np,
+            num_sampled_tokens=num_sampled_tokens,
+        )
+        return sampled_token_ids_np, num_sampled_tokens

     def execute_model(
         self,
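Note: a minimal numpy sketch (illustration only, not part of the commit) of how the chunked-prefill mask above gates token accounting: a request still mid-way through a chunked prefill samples 0 tokens this step, a completed one samples 1. The batch values are hypothetical.

import numpy as np

# Hypothetical 3-request batch; request 1 is still chunked-prefilling.
is_chunked_prefilling = np.array([False, True, False])
num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)
print(num_sampled_tokens)  # [1 0 1] -> request 1 appends no token this step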
@@ -372,17 +379,19 @@
     ):
         self.update_states(scheduler_output)
         if scheduler_output.total_num_scheduled_tokens == 0:
-            return
+            return EMPTY_MODEL_RUNNER_OUTPUT
         input_batch = self.prepare_inputs(scheduler_output)
+        num_tokens = input_batch.num_tokens_after_padding
         with set_forward_context(
             input_batch.attn_metadata,
             self.vllm_config,
+            num_tokens=num_tokens,
         ):
             hidden_states = self.model(
-                input_ids=input_batch.input_ids,
-                positions=input_batch.positions,
+                input_ids=input_batch.input_ids[:num_tokens],
+                positions=input_batch.positions[:num_tokens],
             )

         # Compute logits to sample next tokens.
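Note: a small sketch (illustration only; the buffer capacity and token counts are hypothetical) of the slicing introduced above: the persistent input buffers are allocated at a fixed maximum size, and only the first num_tokens_after_padding entries are fed to the forward pass.

import torch

MAX_NUM_TOKENS = 8192  # hypothetical capacity of the persistent buffer
input_ids = torch.zeros(MAX_NUM_TOKENS, dtype=torch.int64)
num_tokens = 24        # e.g. 20 scheduled tokens padded up to 24
model_input = input_ids[:num_tokens]  # a view, not a copy
assert model_input.shape == (24,)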
@@ -393,5 +402,19 @@
         prompt_logprobs = self.compute_prompt_logprobs(hidden_states,
                                                        input_batch)

-        output = self.postprocess(sampler_output, input_batch)
-        return output
+        sampled_token_ids_np, num_sampled_tokens = self.postprocess(
+            sampler_output, input_batch)
+        req_id_to_index = {
+            req_id: i
+            for i, req_id in enumerate(input_batch.req_ids)
+        }
+        return ModelRunnerOutput(
+            req_ids=input_batch.req_ids,
+            req_id_to_index=req_id_to_index,
+            sampled_token_ids=sampled_token_ids_np.tolist(),
+            logprobs=sampler_output.logprobs_tensors,
+            prompt_logprobs_dict={},
+            pooler_output=[],
+            kv_connector_output=None,
+            num_nans_in_logits=None,
+        )
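Note: a sketch (illustration only; the request IDs and token values are hypothetical) of the shape contract behind sampled_token_ids=sampled_token_ids_np.tolist(): with one sampled token per request the array is (num_reqs, 1), so .tolist() yields one list per request, addressable through req_id_to_index.

import numpy as np

req_ids = ["req-0", "req-1"]                  # hypothetical request IDs
sampled_token_ids_np = np.array([[42], [7]])  # shape (num_reqs, 1)
req_id_to_index = {req_id: i for i, req_id in enumerate(req_ids)}
print(sampled_token_ids_np.tolist()[req_id_to_index["req-1"]])  # [7]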

View File

@@ -3,6 +3,8 @@
 from dataclasses import dataclass
 from typing import Optional, Union

+import numba
+import numba.types as types
 import numpy as np
 import torch
@@ -160,3 +162,47 @@ class RequestState:
         if self.pin_memory:
             cpu_tensor = cpu_tensor.pin_memory()
         return cpu_tensor.to(device=self.device, non_blocking=True)
+
+    def append_token_ids(
+        self,
+        req_indices: np.ndarray,
+        sampled_ids: np.ndarray,
+        num_sampled_tokens: np.ndarray,
+    ) -> None:
+        _append_token_ids(
+            req_indices,
+            sampled_ids,
+            num_sampled_tokens,
+            self.token_ids,
+            self.num_tokens,
+        )
+
+
+@numba.jit(
+    [
+        types.none(
+            types.int32[:],
+            types.int64[:, :],
+            types.int32[:],
+            types.int32[:, :],
+            types.int32[:],
+        )
+    ],
+    nopython=True,
+    cache=True,
+)
+def _append_token_ids(
+    req_indices: np.ndarray,
+    sampled_ids: np.ndarray,
+    num_sampled_tokens: np.ndarray,
+    token_ids: np.ndarray,
+    num_tokens: np.ndarray,
+) -> None:
+    num_reqs = num_sampled_tokens.shape[0]
+    for i in range(num_reqs):
+        req_idx = req_indices[i]
+        n = num_sampled_tokens[i]
+        start_idx = num_tokens[req_idx]
+        end_idx = start_idx + n
+        token_ids[req_idx, start_idx:end_idx] = sampled_ids[i, :n]
+        num_tokens[req_idx] = end_idx
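Note: the kernel above is a variable-length scatter-append per request, which has no single vectorized numpy equivalent; hence the cached nopython numba kernel. A pure-Python/numpy reference of the same semantics (illustration only, on hypothetical toy buffers):

import numpy as np

token_ids = np.zeros((4, 8), dtype=np.int32)        # per-request token rows
num_tokens = np.array([3, 5, 0, 2], dtype=np.int32)

req_indices = np.array([1, 3], dtype=np.int32)      # batch slot -> state row
sampled_ids = np.array([[11], [22]], dtype=np.int64)
num_sampled = np.array([1, 0], dtype=np.int32)      # 0: chunked prefill, skip

for i in range(req_indices.shape[0]):
    req_idx = req_indices[i]
    n = num_sampled[i]
    start = num_tokens[req_idx]
    token_ids[req_idx, start:start + n] = sampled_ids[i, :n]
    num_tokens[req_idx] = start + n

print(num_tokens)  # [3 6 0 2] -- row 3 unchanged since it sampled 0 tokens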