Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-09-18 13:10:35 -07:00
parent 92f337faeb
commit cbdb47dc01
2 changed files with 79 additions and 10 deletions

View File

@ -16,6 +16,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.sample.sampler import SamplerOutput from vllm.v1.sample.sampler import SamplerOutput
from vllm.v1.worker.gpu.attn_utils import (get_kv_cache_spec, from vllm.v1.worker.gpu.attn_utils import (get_kv_cache_spec,
init_attn_backend, init_kv_cache) init_attn_backend, init_kv_cache)
@ -356,15 +357,21 @@ class GPUModelRunner:
self, self,
sampler_output: SamplerOutput, sampler_output: SamplerOutput,
input_batch: InputBatch, input_batch: InputBatch,
) -> np.ndarray: ) -> tuple[np.ndarray, np.ndarray]:
# Get the number of sampled tokens. # Get the number of sampled tokens.
# 0 if chunked-prefilling, 1 if not. # 0 if chunked-prefilling, 1 if not.
is_chunked_prefilling = input_batch.is_chunked_prefilling is_chunked_prefilling = input_batch.is_chunked_prefilling
num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32) num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)
# Increment the number of tokens.
idx_mapping_np = input_batch.idx_mapping_np # Update the token IDs and the number of tokens.
self.req_states.num_tokens[idx_mapping_np] += num_sampled_tokens sampled_token_ids_cpu = sampler_output.sampled_token_ids.cpu()
return num_sampled_tokens sampled_token_ids_np = sampled_token_ids_cpu.numpy()
self.req_states.append_token_ids(
input_batch.idx_mapping_np,
sampled_token_ids_np,
num_sampled_tokens=num_sampled_tokens,
)
return sampled_token_ids_np, num_sampled_tokens
def execute_model( def execute_model(
self, self,
@ -372,17 +379,19 @@ class GPUModelRunner:
): ):
self.update_states(scheduler_output) self.update_states(scheduler_output)
if scheduler_output.total_num_scheduled_tokens == 0: if scheduler_output.total_num_scheduled_tokens == 0:
return return EMPTY_MODEL_RUNNER_OUTPUT
input_batch = self.prepare_inputs(scheduler_output) input_batch = self.prepare_inputs(scheduler_output)
num_tokens = input_batch.num_tokens_after_padding
with set_forward_context( with set_forward_context(
input_batch.attn_metadata, input_batch.attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=num_tokens,
): ):
hidden_states = self.model( hidden_states = self.model(
input_ids=input_batch.input_ids, input_ids=input_batch.input_ids[:num_tokens],
positions=input_batch.positions, positions=input_batch.positions[:num_tokens],
) )
# Compute logits to sample next tokens. # Compute logits to sample next tokens.
@ -393,5 +402,19 @@ class GPUModelRunner:
prompt_logprobs = self.compute_prompt_logprobs(hidden_states, prompt_logprobs = self.compute_prompt_logprobs(hidden_states,
input_batch) input_batch)
output = self.postprocess(sampler_output, input_batch) sampled_token_ids_np, num_sampled_tokens = self.postprocess(
return output sampler_output, input_batch)
req_id_to_index = {
req_id: i
for i, req_id in enumerate(input_batch.req_ids)
}
return ModelRunnerOutput(
req_ids=input_batch.req_ids,
req_id_to_index=req_id_to_index,
sampled_token_ids=sampled_token_ids_np.tolist(),
logprobs=sampler_output.logprobs_tensors,
prompt_logprobs_dict={},
pooler_output=[],
kv_connector_output=None,
num_nans_in_logits=None,
)

View File

@ -3,6 +3,8 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Union from typing import Optional, Union
import numba
import numba.types as types
import numpy as np import numpy as np
import torch import torch
@ -160,3 +162,47 @@ class RequestState:
if self.pin_memory: if self.pin_memory:
cpu_tensor = cpu_tensor.pin_memory() cpu_tensor = cpu_tensor.pin_memory()
return cpu_tensor.to(device=self.device, non_blocking=True) return cpu_tensor.to(device=self.device, non_blocking=True)
def append_token_ids(
    self,
    req_indices: np.ndarray,
    sampled_ids: np.ndarray,
    num_sampled_tokens: np.ndarray,
) -> None:
    """Append freshly sampled tokens to the stored per-request token state.

    The three arguments are forwarded unchanged to the compiled
    ``_append_token_ids`` helper, together with this object's
    ``token_ids`` and ``num_tokens`` arrays, which the helper updates
    in place.

    Args:
        req_indices: For each batch entry, the row of the state arrays
            that belongs to that entry.
        sampled_ids: Sampled token IDs, one row per batch entry.
        num_sampled_tokens: How many leading IDs of each row of
            ``sampled_ids`` should be appended.
    """
    # Bind the state buffers once so the call below reads as a flat
    # list of five arrays in the order the helper expects.
    token_buffer = self.token_ids
    token_counts = self.num_tokens
    _append_token_ids(
        req_indices,
        sampled_ids,
        num_sampled_tokens,
        token_buffer,
        token_counts,
    )
@numba.jit(
    [
        types.none(
            types.int32[:],  # req_indices
            types.int64[:, :],  # sampled_ids
            types.int32[:],  # num_sampled_tokens
            types.int32[:, :],  # token_ids
            types.int32[:],  # num_tokens
        )
    ],
    nopython=True,
    cache=True,
)
def _append_token_ids(
    req_indices: np.ndarray,
    sampled_ids: np.ndarray,
    num_sampled_tokens: np.ndarray,
    token_ids: np.ndarray,
    num_tokens: np.ndarray,
) -> None:
    """Copy sampled tokens into the per-request token buffer, in place.

    For every batch entry ``i`` the first ``num_sampled_tokens[i]`` values
    of ``sampled_ids[i]`` are written into row ``req_indices[i]`` of
    ``token_ids``, starting at the position currently stored in
    ``num_tokens`` for that row. That stored position is then advanced by
    the number of values written.

    Args:
        req_indices: 1-D int32 array mapping batch entry -> state row.
        sampled_ids: 2-D int64 array of sampled token IDs, one row per
            batch entry.
        num_sampled_tokens: 1-D int32 array; its length determines how
            many batch entries are processed.
        token_ids: 2-D int32 array that receives the tokens (mutated).
        num_tokens: 1-D int32 array of current per-row token counts
            (mutated).
    """
    batch_size = len(num_sampled_tokens)
    for entry in range(batch_size):
        row = req_indices[entry]
        count = num_sampled_tokens[entry]
        # The current length of the row is where the new tokens begin.
        write_from = num_tokens[row]
        write_to = write_from + count
        token_ids[row, write_from:write_to] = sampled_ids[entry, :count]
        num_tokens[row] = write_to