commit 4be2c66e37
parent d30c0d50a6
Author: Woosuk Kwon <woosuk@thinkingmachines.ai>
Date:   2025-09-19 09:35:38 +00:00

Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>

3 changed files with 127 additions and 98 deletions

vllm/v1/worker/gpu/input_batch.py (path inferred from the imports in the second file)

@@ -5,9 +5,11 @@ from dataclasses import dataclass
 from typing import Any

 import numba
-import numba.types as types
 import numpy as np
 import torch
+from numba import types
+import triton
+import triton.language as tl

 from vllm.v1.utils import CpuGpuBuffer
@@ -161,3 +163,49 @@ def prepare_inputs(
     query_start_loc[num_reqs + 1:].fill(cu_num_tokens)
     # Fill unused with 0 for full cuda graph mode.
     seq_lens[num_reqs:].fill(0)
+
+
+@triton.jit
+def _combine_last_token_ids_kernel(
+    input_ids_ptr,
+    idx_mapping_ptr,
+    last_token_ids_ptr,
+    query_start_loc_ptr,
+    seq_lens_ptr,
+    num_tokens_ptr,
+):
+    batch_idx = tl.program_id(0)
+    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
+
+    seq_len = tl.load(seq_lens_ptr + batch_idx)
+    num_tokens = tl.load(num_tokens_ptr + req_state_idx)
+    if seq_len < num_tokens:
+        # Chunked prefilling.
+        return
+
+    last_token_id = tl.load(last_token_ids_ptr + req_state_idx)
+    if last_token_id == -1:
+        return
+
+    end = tl.load(query_start_loc_ptr + batch_idx + 1)
+    tl.store(input_ids_ptr + end - 1, last_token_id)
+
+
+def combine_last_token_ids(
+    input_ids: torch.Tensor,
+    idx_mapping: torch.Tensor,
+    last_token_ids: torch.Tensor,
+    query_start_loc: torch.Tensor,
+    seq_lens: torch.Tensor,
+    num_tokens: torch.Tensor,
+) -> torch.Tensor:
+    num_reqs = seq_lens.shape[0]
+    _combine_last_token_ids_kernel[(num_reqs, )](
+        input_ids,
+        idx_mapping,
+        last_token_ids,
+        query_start_loc,
+        seq_lens,
+        num_tokens,
+    )
+    return input_ids
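
For readers unfamiliar with Triton: the kernel above launches one program per
batch row, and is equivalent to the following eager PyTorch sketch. This is
illustrative only, not part of the commit; the tensor shapes are inferred from
the call site in the model runner below.

    import torch

    def combine_last_token_ids_ref(
        input_ids: torch.Tensor,       # [num_tokens], flattened batch
        idx_mapping: torch.Tensor,     # [num_reqs], batch row -> request slot
        last_token_ids: torch.Tensor,  # [max_num_reqs, 1], -1 = none sampled
        query_start_loc: torch.Tensor, # [num_reqs + 1], cumulative offsets
        seq_lens: torch.Tensor,        # [num_reqs]
        num_tokens: torch.Tensor,      # [max_num_reqs], tokens per request
    ) -> torch.Tensor:
        for batch_idx in range(seq_lens.shape[0]):
            req_state_idx = int(idx_mapping[batch_idx])
            # A request whose scheduled tokens do not reach its last known
            # token is still chunk-prefilling; all its inputs are prompt
            # tokens already present in input_ids, so nothing to overwrite.
            if seq_lens[batch_idx] < num_tokens[req_state_idx]:
                continue
            last_token_id = int(last_token_ids[req_state_idx, 0])
            if last_token_id == -1:  # Nothing sampled for this request yet.
                continue
            # Overwrite the final input position of this row with the token
            # sampled in the previous step, which lives only on the GPU.
            end = int(query_start_loc[batch_idx + 1])
            input_ids[end - 1] = last_token_id
        return input_ids
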

vllm/v1/worker/gpu/model_runner.py (path inferred from the GPUModelRunner class below)

@@ -27,6 +27,7 @@ from vllm.v1.worker.gpu.block_table import BlockTables
 from vllm.v1.worker.gpu.dist_utils import (all_gather_sampler_output,
                                            evenly_split)
 from vllm.v1.worker.gpu.input_batch import (InputBatch, InputBuffers,
+                                            combine_last_token_ids,
                                             prepare_inputs)
 from vllm.v1.worker.gpu.sampler import Sampler
 from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
@@ -158,8 +159,8 @@ class GPUModelRunner:
             num_tokens=num_tokens,
         ):
             hidden_states = self.model(
-                input_ids=input_batch.input_ids[:num_tokens],
-                positions=input_batch.positions[:num_tokens],
+                input_ids=input_batch.input_ids,
+                positions=input_batch.positions,
             )
         sample_hidden_states = hidden_states[input_batch.logits_indices]
         return hidden_states, sample_hidden_states
@@ -205,7 +206,7 @@ class GPUModelRunner:
             [] for _ in range(self.block_tables.num_kv_cache_groups))
         overwrite: list[bool] = []

-        # Add new requests to the cached states.
+        # Add new requests.
         for new_req_data in scheduler_output.scheduled_new_reqs:
             req_id = new_req_data.req_id
             self.req_states.add_request(
@@ -223,7 +224,7 @@ class GPUModelRunner:
                 new_block_ids[i].extend(block_ids)
             overwrite.append(True)

-        # Update the states of the running/resumed requests.
+        # Add new blocks for the existing requests.
         cached_reqs = scheduler_output.scheduled_cached_reqs
         for i, req_id in enumerate(cached_reqs.req_ids):
             req_index = self.req_states.req_id_to_index[req_id]
@@ -237,9 +238,6 @@ class GPUModelRunner:
                 new_block_ids[group_id].extend(block_ids)
             overwrite.append(False)

-            self.req_states.num_computed_tokens[req_index] = (
-                cached_reqs.num_computed_tokens[i])
-
         if req_indices:
             self.block_tables.append_block_ids(
                 req_indices=req_indices,
@@ -275,54 +273,61 @@ class GPUModelRunner:
         # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
         block_tables = self.block_tables.gather_block_tables(idx_mapping)

-        input_ids = self.input_buffers.input_ids
-        positions = self.input_buffers.positions
-        query_start_loc = self.input_buffers.query_start_loc
-        seq_lens = self.input_buffers.seq_lens
         prepare_inputs(
             idx_mapping_np,
-            self.req_states.token_ids,
+            self.req_states.prompt_token_ids,
             self.req_states.num_computed_tokens,
             num_scheduled_tokens,
-            input_ids.np,
-            positions.np,
-            query_start_loc.np,
-            seq_lens.np,
+            self.input_buffers.input_ids.np,
+            self.input_buffers.positions.np,
+            self.input_buffers.query_start_loc.np,
+            self.input_buffers.seq_lens.np,
         )
-        input_ids.copy_to_gpu(num_tokens)
-        positions.copy_to_gpu(num_tokens)
+        self.input_buffers.input_ids.copy_to_gpu(num_tokens)
+        self.input_buffers.positions.copy_to_gpu(num_tokens)

         # NOTE(woosuk): We should copy the whole query_start_loc and seq_lens
         # tensors from CPU to GPU, because they may include paddings needed
         # for full CUDA graph mode.
-        query_start_loc.copy_to_gpu()
-        query_start_loc_cpu = query_start_loc.cpu[:num_reqs + 1]
+        self.input_buffers.query_start_loc.copy_to_gpu()
+        self.input_buffers.seq_lens.copy_to_gpu()
+        query_start_loc = self.input_buffers.query_start_loc
         query_start_loc_gpu = query_start_loc.gpu[:num_reqs + 1]
+        query_start_loc_cpu = query_start_loc.cpu[:num_reqs + 1]
         max_query_len = int(num_scheduled_tokens.max())

-        seq_lens.copy_to_gpu()
-        seq_lens_cpu = seq_lens.cpu[:num_reqs]
-        seq_lens_np = seq_lens.np[:num_reqs]
+        seq_lens_gpu = self.input_buffers.seq_lens.gpu[:num_reqs]
+        seq_lens_cpu = self.input_buffers.seq_lens.cpu[:num_reqs]
+        seq_lens_np = self.input_buffers.seq_lens.np[:num_reqs]
         max_seq_len = int(seq_lens_np.max())
-        seq_lens_gpu = seq_lens.gpu[:num_reqs]
-        num_computed_tokens_np = self.req_states.num_computed_tokens[
-            idx_mapping_np]
-        num_computed_tokens_cpu = torch.from_numpy(num_computed_tokens_np)
-        is_chunked_prefilling = (seq_lens_np
-                                 < self.req_states.num_tokens[idx_mapping_np])
+
+        # Some input token ids are directly read from the last sampled tokens.
+        combine_last_token_ids(
+            self.input_buffers.input_ids.gpu,
+            idx_mapping,
+            self.req_states.last_sampled_tokens,
+            query_start_loc_gpu,
+            seq_lens_gpu,
+            self.req_states.num_tokens.copy_to_gpu(),
+        )

-        # Slot mappings: [num_kv_cache_groups, num_tokens]
+        # Compute slot mappings: [num_kv_cache_groups, num_tokens]
         slot_mappings = self.block_tables.compute_slot_mappings(
-            query_start_loc_gpu, positions.gpu[:num_tokens])
+            query_start_loc_gpu, self.input_buffers.positions.gpu[:num_tokens])
+
+        num_computed_tokens_cpu = torch.from_numpy(
+            self.req_states.num_computed_tokens[idx_mapping_np])
+        # Whether the request is chunked-prefilling or not.
+        is_chunked_prefilling = (
+            seq_lens_np < self.req_states.num_tokens.np[idx_mapping_np])

         # Logits indices to sample next token from.
         logits_indices = query_start_loc_gpu[1:] - 1
         num_logits_indices = logits_indices.size(0)

         # Layer name -> attention metadata.
         attn_metadata: dict[str, Any] = {}
-        for i, kv_cache_spec in enumerate(
-                self.kv_cache_config.kv_cache_groups):
+        kv_cache_groups = self.kv_cache_config.kv_cache_groups
+        for i, kv_cache_spec in enumerate(kv_cache_groups):
             block_table = block_tables[i]
             slot_mapping = slot_mappings[i]
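
The is_chunked_prefilling test above compares each request's post-step
sequence length against the number of tokens the request currently owns: a
request whose scheduled chunk stops short of its last known token cannot
sample this step. A toy NumPy illustration (values invented):

    import numpy as np

    # Request 0: 100-token prompt, only tokens 0..63 scheduled this step,
    # so seq_len (64) < num_tokens (100): still chunk-prefilling.
    # Request 1: all 100 tokens are in place, so it samples this step.
    seq_lens_np = np.array([64, 100], dtype=np.int32)
    num_tokens_np = np.array([100, 100], dtype=np.int32)
    is_chunked_prefilling = seq_lens_np < num_tokens_np
    print(is_chunked_prefilling)  # [ True False]
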
@@ -352,6 +357,8 @@ class GPUModelRunner:
             for layer_name in kv_cache_spec.layer_names:
                 attn_metadata[layer_name] = metadata

+        input_ids = self.input_buffers.input_ids.gpu[:num_tokens_after_padding]
+        positions = self.input_buffers.positions.gpu[:num_tokens_after_padding]
         return InputBatch(
             req_ids=req_ids,
             num_reqs=num_reqs,
@@ -361,8 +368,8 @@ class GPUModelRunner:
             num_tokens=num_tokens,
             num_tokens_after_padding=num_tokens_after_padding,
             is_chunked_prefilling=is_chunked_prefilling,
-            input_ids=input_ids.gpu,
-            positions=positions.gpu,
+            input_ids=input_ids,
+            positions=positions,
             attn_metadata=attn_metadata,
             logits_indices=logits_indices,
         )
@@ -412,10 +419,20 @@ class GPUModelRunner:
         sampler_output: SamplerOutput,
         input_batch: InputBatch,
     ) -> AsyncOutput:
+        # Store the last sampled token ids.
+        self.req_states.last_sampled_tokens[input_batch.idx_mapping] = (
+            sampler_output.sampled_token_ids)
+        # Get the number of sampled tokens.
+        # 0 if chunked-prefilling, 1 if not.
+        is_chunked_prefilling = input_batch.is_chunked_prefilling
+        num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)
+        # Increment the number of tokens.
+        idx_mapping_np = input_batch.idx_mapping_np
+        self.req_states.num_tokens.np[idx_mapping_np] += num_sampled_tokens
         # Increment the number of computed tokens.
         self.req_states.num_computed_tokens[idx_mapping_np] += (
             input_batch.num_scheduled_tokens)

         model_runner_output = ModelRunnerOutput(
             req_ids=input_batch.req_ids,
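
The (~is_chunked_prefilling).astype(np.int32) idiom above turns the boolean
mask into per-request sampled-token counts in one vectorized step. A toy
example (values invented):

    import numpy as np

    is_chunked_prefilling = np.array([False, True, False])
    # Inverting the mask and casting yields 1 for requests that sampled a
    # token this step and 0 for requests still prefilling.
    num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)
    print(num_sampled_tokens)  # [1 0 1]
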
@@ -450,8 +467,8 @@ class GPUModelRunner:
             num_tokens=num_tokens,
         ):
             hidden_states = self.model(
-                input_ids=input_batch.input_ids[:num_tokens],
-                positions=input_batch.positions[:num_tokens],
+                input_ids=input_batch.input_ids,
+                positions=input_batch.positions,
             )

         sampler_output = self.sample(hidden_states, input_batch)
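
A note on why input_batch.input_ids is no longer sliced with [:num_tokens] at
the call sites above: the buffers are now pre-sliced once to
num_tokens_after_padding when the InputBatch is built, which matches how full
CUDA graph capture works. A replayed graph always reads fixed addresses and
fixed shapes, so the model must see the same padded slice of the same
persistent buffer on every step. A minimal sketch of that convention (all
names and sizes here are illustrative, not vLLM's actual API):

    import torch

    MAX_NUM_TOKENS = 16              # capacity of the persistent buffers
    num_tokens = 3                   # tokens actually scheduled this step
    num_tokens_after_padding = 8     # hypothetical graph capture size

    input_ids_buf = torch.zeros(MAX_NUM_TOKENS, dtype=torch.int64)
    positions_buf = torch.zeros(MAX_NUM_TOKENS, dtype=torch.int64)

    # Slice once to the padded length; the tail [num_tokens:] is padding
    # that the captured graph still reads, so it must stay allocated.
    input_ids = input_ids_buf[:num_tokens_after_padding]
    positions = positions_buf[:num_tokens_after_padding]
    # model(input_ids=input_ids, positions=positions) is then safe to
    # capture once and replay for any step padded to the same size.
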

vllm/v1/worker/gpu/states.py (path inferred from the RequestState import above)

@@ -3,8 +3,6 @@
 from dataclasses import dataclass
 from typing import Optional

-import numba
-import numba.types as types
 import numpy as np
 import torch
@@ -76,21 +74,22 @@ class RequestState:
         self.index_to_req_id: dict[int, str] = {}
         self.free_indices = list(range(max_num_reqs))

-        # TODO(woosuk): Because the token_ids tensor can be very big, we only
-        # initialize it on CPU memory.
-        self.token_ids = np.zeros(
+        # NOTE(woosuk): Strictly speaking, it contains prompt + some output
+        # because of preemption.
+        self.prompt_token_ids = np.zeros(
             (self.max_num_reqs, self.max_model_len),
             dtype=np.int32,
         )
-        self.num_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
+        self.num_tokens = self._make_buffer(self.max_num_reqs,
+                                            dtype=torch.int32)
         self.num_computed_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
         self.num_prompt_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)

-        # Last sampled token ids.
-        self.last_token = torch.zeros(
+        # Last sampled tokens.
+        self.last_sampled_tokens = torch.zeros(
             self.max_num_reqs,
-            dtype=torch.int32,
-            device=self.device,
+            1,
+            dtype=torch.int64,
+            device=device,
         )

         # Sampling parameters.
@@ -110,6 +109,12 @@ class RequestState:
                                device=self.device,
                                pin_memory=self.pin_memory)

+    def _make_buffer(self, size: int, dtype: torch.dtype) -> CpuGpuBuffer:
+        return CpuGpuBuffer(size,
+                            dtype=dtype,
+                            device=self.device,
+                            pin_memory=self.pin_memory)
+
     @property
     def num_reqs(self) -> int:
         return len(self.req_id_to_index)
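
The CpuGpuBuffer returned by _make_buffer is used throughout this commit via
its .np, .cpu, and .gpu views plus copy_to_gpu(). A rough, self-contained
sketch of that contract, inferred from these call sites rather than from the
actual vllm.v1.utils implementation:

    from typing import Optional

    import torch

    class CpuGpuBufferSketch:
        """A pinned CPU tensor, a NumPy view of it, and a GPU mirror."""

        def __init__(self, *size: int, dtype: torch.dtype,
                     device: torch.device, pin_memory: bool):
            # Pinned host memory enables async (non_blocking) H2D copies.
            self.cpu = torch.zeros(*size, dtype=dtype, pin_memory=pin_memory)
            self.np = self.cpu.numpy()  # zero-copy view for CPU-side fills
            self.gpu = torch.zeros(*size, dtype=dtype, device=device)

        def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor:
            # Copy the first n elements, or the whole buffer (e.g. to carry
            # the padding needed for full CUDA graph mode).
            if n is None:
                self.gpu.copy_(self.cpu, non_blocking=True)
                return self.gpu
            self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
            return self.gpu[:n]
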
@@ -126,11 +131,14 @@ class RequestState:
         self.req_id_to_index[req_id] = req_idx
         self.index_to_req_id[req_idx] = req_id

+        # NOTE(woosuk): Strictly speaking, "prompt_len" here may include
+        # output tokens, if the request is resumed from preemption.
         prompt_len = len(prompt_token_ids)
-        self.num_tokens[req_idx] = prompt_len
         self.num_prompt_tokens[req_idx] = prompt_len
-        self.token_ids[req_idx, :prompt_len] = prompt_token_ids
+        self.prompt_token_ids[req_idx, :prompt_len] = prompt_token_ids
+        self.num_tokens.np[req_idx] = prompt_len
         self.num_computed_tokens[req_idx] = num_computed_tokens
+        # TODO(woosuk): Optimize.
+        self.last_sampled_tokens[req_idx].fill_(-1)

         self.temperature.np[req_idx] = sampling_params.temperature
         self.top_p.np[req_idx] = sampling_params.top_p
@@ -197,50 +205,6 @@ class RequestState:
             max_num_logprobs=max_num_logprobs,
         )

-    def append_token_ids(
-        self,
-        req_indices: np.ndarray,
-        sampled_ids: np.ndarray,
-        num_sampled_tokens: np.ndarray,
-    ) -> None:
-        _append_token_ids(
-            req_indices,
-            sampled_ids,
-            num_sampled_tokens,
-            self.token_ids,
-            self.num_tokens,
-        )
-
-
-@numba.jit(
-    [
-        types.none(
-            types.int32[:],
-            types.int64[:, :],
-            types.int32[:],
-            types.int32[:, :],
-            types.int32[:],
-        )
-    ],
-    nopython=True,
-    cache=True,
-)
-def _append_token_ids(
-    req_indices: np.ndarray,
-    sampled_ids: np.ndarray,
-    num_sampled_tokens: np.ndarray,
-    token_ids: np.ndarray,
-    num_tokens: np.ndarray,
-) -> None:
-    num_reqs = num_sampled_tokens.shape[0]
-    for i in range(num_reqs):
-        req_idx = req_indices[i]
-        n = num_sampled_tokens[i]
-        start_idx = num_tokens[req_idx]
-        end_idx = start_idx + n
-        token_ids[req_idx, start_idx:end_idx] = sampled_ids[i, :n]
-        num_tokens[req_idx] = end_idx
-
-
 class Param: