[Model Runner V2] Change bookkeeping logic in preparation for spec decoding (#29194)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon 2025-11-23 09:42:52 -08:00 committed by GitHub
parent 6fb0215eee
commit 7f12c82fa6
6 changed files with 269 additions and 140 deletions

View File

@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
import numpy as np
import torch
from vllm.v1.outputs import (
@@ -18,7 +17,7 @@ class AsyncOutput(AsyncModelRunnerOutput):
self,
model_runner_output: ModelRunnerOutput,
sampler_output: SamplerOutput,
num_sampled_tokens: np.ndarray,
num_sampled_tokens: torch.Tensor,
copy_stream: torch.cuda.Stream,
copy_event: torch.cuda.Event,
):
@@ -52,6 +51,7 @@ class AsyncOutput(AsyncModelRunnerOutput):
)
else:
self.logprobs_tensors = None
self.num_sampled_tokens = num_sampled_tokens.to("cpu", non_blocking=True)
self.prompt_logprobs_dict: dict[str, LogprobsTensors | None] = {}
if self.model_runner_output.prompt_logprobs_dict:
for k, v in self.model_runner_output.prompt_logprobs_dict.items():
@@ -63,6 +63,7 @@ class AsyncOutput(AsyncModelRunnerOutput):
def get_output(self) -> ModelRunnerOutput:
self.copy_event.synchronize()
num_sampled_tokens_np = self.num_sampled_tokens.numpy()
# NOTE(woosuk): The following code is to ensure compatibility with
# the existing model runner.
@@ -71,7 +72,7 @@ class AsyncOutput(AsyncModelRunnerOutput):
sampled_token_ids: list[list[int]] = self.sampled_token_ids.tolist()
num_reqs = len(sampled_token_ids)
for i in range(num_reqs):
del sampled_token_ids[i][self.num_sampled_tokens[i] :]
del sampled_token_ids[i][num_sampled_tokens_np[i] :]
self.model_runner_output.sampled_token_ids = sampled_token_ids
if self.logprobs_tensors is not None:
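The AsyncOutput changes above follow the async device-to-host copy pattern: num_sampled_tokens is now a GPU tensor, its copy to the host is issued with non_blocking=True on a dedicated copy stream, and get_output() only reads the host copy after copy_event has completed. A minimal standalone sketch of that pattern (illustrative only; the helper names below are not part of vLLM, and pinned host memory is assumed for the copy to actually overlap):

import torch

def start_async_d2h_copy(
    gpu_tensor: torch.Tensor,
    copy_stream: torch.cuda.Stream,
    copy_event: torch.cuda.Event,
) -> torch.Tensor:
    # Let the copy stream wait for work already queued on the current stream,
    # so the copy observes the final values of gpu_tensor.
    copy_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(copy_stream):
        cpu_tensor = gpu_tensor.to("cpu", non_blocking=True)
        copy_event.record(copy_stream)
    return cpu_tensor

def read_async_copy(cpu_tensor: torch.Tensor, copy_event: torch.cuda.Event):
    copy_event.synchronize()  # block until the D2H copy has finished
    return cpu_tensor.numpy()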

View File

@@ -3,6 +3,7 @@
from collections.abc import Sequence
from typing import Any, cast
import numpy as np
import torch
from vllm.attention.backends.abstract import AttentionBackend
@@ -145,8 +146,9 @@ def build_attn_metadata(
num_reqs: int,
num_tokens: int,
query_start_loc: CpuGpuBuffer,
seq_lens: CpuGpuBuffer,
num_computed_tokens_cpu: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_np: np.ndarray,
num_computed_tokens_cpu: torch.Tensor | None,
block_tables: Sequence[torch.Tensor],
slot_mappings: torch.Tensor,
kv_cache_config: KVCacheConfig,
@@ -154,9 +156,9 @@
query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1]
max_query_len = int(query_start_loc.np[: num_reqs + 1].max())
seq_lens_gpu = seq_lens.gpu[:num_reqs]
seq_lens_cpu = seq_lens.cpu[:num_reqs]
max_seq_len = int(seq_lens.np[:num_reqs].max())
seq_lens = seq_lens[:num_reqs]
seq_lens_cpu = torch.from_numpy(seq_lens_np)
max_seq_len = int(seq_lens_np.max())
attn_metadata: dict[str, Any] = {}
kv_cache_groups = kv_cache_config.kv_cache_groups
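build_attn_metadata now receives the GPU seq_lens tensor and the host-side seq_lens_np array as separate arguments; the CPU tensor is rebuilt with torch.from_numpy, which wraps the numpy array without copying it. A small standalone illustration of that zero-copy relationship (not vLLM code):

import numpy as np
import torch

seq_lens_np = np.array([5, 9, 3], dtype=np.int32)
seq_lens_cpu = torch.from_numpy(seq_lens_np)  # shares memory with the array
seq_lens_np[0] = 7                            # visible through the tensor
assert seq_lens_cpu[0].item() == 7
max_seq_len = int(seq_lens_np.max())          # stays a cheap host-side op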
@@ -167,7 +169,7 @@
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc_gpu,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens_gpu,
seq_lens=seq_lens,
seq_lens_cpu=seq_lens_cpu,
max_seq_len=max_seq_len,
num_computed_tokens_cpu=num_computed_tokens_cpu,

View File

@@ -101,14 +101,13 @@ class CudaGraphManager:
# Prepare dummy inputs.
input_ids = input_buffers.input_ids.gpu[:batch_size]
positions = input_buffers.positions.gpu[:batch_size]
positions = input_buffers.positions[:batch_size]
input_buffers.query_start_loc.np[: batch_size + 1] = np.arange(batch_size + 1)
input_buffers.query_start_loc.np[batch_size:] = batch_size
input_buffers.query_start_loc.copy_to_gpu()
input_buffers.seq_lens.np[:batch_size] = self.max_model_len
input_buffers.seq_lens.np[batch_size:] = 0
input_buffers.seq_lens.copy_to_gpu()
input_buffers.seq_lens[:batch_size] = self.max_model_len
input_buffers.seq_lens[batch_size:] = 0
input_block_tables = [x[:batch_size] for x in block_tables.input_block_tables]
slot_mappings = block_tables.slot_mappings[:, :batch_size]
@@ -119,6 +118,7 @@ class CudaGraphManager:
num_tokens=batch_size,
query_start_loc=input_buffers.query_start_loc,
seq_lens=input_buffers.seq_lens,
seq_lens_np=np.full(batch_size, self.max_model_len, dtype=np.int32),
num_computed_tokens_cpu=None, # FIXME
block_tables=input_block_tables,
slot_mappings=slot_mappings,
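The dummy batch writes max_model_len and zero padding into the persistent seq_lens buffer in place because the kernels captured into the CUDA graph keep reading from those same buffers at replay time. A minimal, self-contained illustration of the capture/replay pattern with static buffers (generic PyTorch, not the CudaGraphManager API):

import torch

static_input = torch.zeros(8, device="cuda")
graph = torch.cuda.CUDAGraph()

# Warm up on a side stream before capture, as recommended by PyTorch.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_output = static_input * 2
torch.cuda.current_stream().wait_stream(s)

with torch.cuda.graph(graph):
    static_output = static_input * 2  # kernels are recorded, not executed

static_input.fill_(3.0)  # update the static buffer in place
graph.replay()           # re-run the recorded kernels on the same buffers
assert torch.allclose(static_output, torch.full_like(static_output, 6.0))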

View File

@@ -32,9 +32,9 @@ class InputBuffers:
self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32)
self.input_ids = self._make_buffer(max_num_tokens, dtype=torch.int32)
self.positions = self._make_buffer(max_num_tokens, dtype=torch.int64)
self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
self.seq_lens = self._make_buffer(max_num_reqs, dtype=torch.int32)
self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
# Structured outputs.
self.bitmask_indices = self._make_buffer(max_num_reqs, dtype=torch.int32)
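InputBuffers now mixes two kinds of storage: CpuGpuBuffer objects created by _make_buffer, which pair a host staging array with a device tensor (exposing .np, .cpu, .gpu and copy_to_gpu()), and plain GPU tensors for positions and seq_lens, which after this change are written directly on the device by the Triton kernels below. A rough sketch of what such a paired buffer looks like (an assumed shape of the helper, not its actual definition):

import torch

class PairedBuffer:
    """Pinned host staging buffer paired with a device tensor (sketch)."""

    def __init__(self, size: int, dtype: torch.dtype, device: torch.device):
        self.cpu = torch.zeros(size, dtype=dtype, pin_memory=True)
        self.np = self.cpu.numpy()  # writable host view of the same memory
        self.gpu = torch.zeros(size, dtype=dtype, device=device)

    def copy_to_gpu(self, n=None) -> torch.Tensor:
        n = self.gpu.shape[0] if n is None else n
        self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
        return self.gpu[:n]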
@@ -107,13 +107,15 @@ class InputBatch:
query_start_loc_np = input_buffers.query_start_loc.np[: num_reqs + 1]
query_start_loc = input_buffers.query_start_loc.copy_to_gpu()[: num_reqs + 1]
# seq_len equals query_len
input_buffers.seq_lens.np[:num_reqs] = num_scheduled_tokens
input_buffers.seq_lens.np[num_reqs:] = 0
seq_lens_np = input_buffers.seq_lens.np[:num_reqs]
seq_lens = input_buffers.seq_lens.copy_to_gpu()[:num_reqs]
seq_lens_np = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
seq_lens_np[-1] += num_tokens % num_reqs
input_buffers.seq_lens[:num_reqs] = num_tokens // num_reqs
input_buffers.seq_lens[num_reqs - 1] += num_tokens % num_reqs
input_buffers.seq_lens[num_reqs:] = 0
seq_lens = input_buffers.seq_lens[:num_reqs]
input_ids = input_buffers.input_ids.copy_to_gpu(num_tokens)
positions = input_buffers.positions.copy_to_gpu(num_tokens)
positions = input_buffers.positions[:num_tokens]
# attn_metadata = defaultdict(lambda: None)
logits_indices = query_start_loc[1:] - 1
return cls(
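For the dummy decode batch, the tokens are split evenly across requests and the remainder goes to the last request, as the num_tokens // num_reqs lines above show. A worked example:

import numpy as np

num_tokens, num_reqs = 10, 3
seq_lens_np = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)  # [3, 3, 3]
seq_lens_np[-1] += num_tokens % num_reqs                                 # [3, 3, 4]
assert int(seq_lens_np.sum()) == num_tokens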
@@ -141,27 +143,25 @@ class InputBatch:
[
types.none(
types.int32[:], # idx_mapping
types.int32[:, :], # token_ids
types.int32[:], # num_computed_tokens
types.int32[:], # num_scheduled_tokens
types.int32[:, :], # prefill_token_ids
types.int32[:], # num_computed_prefill_tokens
types.int32[:], # prefill_len
types.int32[:], # input_ids
types.int64[:], # positions
types.int32[:], # query_start_loc
types.int32[:], # seq_lens
)
],
nopython=True,
cache=True,
)
def _prepare_inputs(
def _prepare_prefill_inputs(
idx_mapping: np.ndarray, # batch_idx -> req_idx
token_ids: np.ndarray, # [N, max_model_len]
num_computed_tokens: np.ndarray, # [N]
num_scheduled_tokens: np.ndarray, # [B]
prefill_token_ids: np.ndarray, # [N, max_model_len]
num_computed_prefill_tokens: np.ndarray, # [N]
prefill_len: np.ndarray, # [N]
input_ids: np.ndarray, # [num_input_tokens]
positions: np.ndarray, # [num_input_tokens]
query_start_loc: np.ndarray, # [B + 1]
seq_lens: np.ndarray, # [B]
) -> None:
num_reqs = num_scheduled_tokens.shape[0]
query_start_loc[0] = 0
@@ -170,62 +170,112 @@ def _prepare_inputs(
for i in range(num_reqs):
req_idx = idx_mapping[i]
query_len = num_scheduled_tokens[i]
start = num_computed_tokens[req_idx]
end = start + query_len
seq_lens[i] = end
start = num_computed_prefill_tokens[req_idx]
end = min(start + query_len, prefill_len[req_idx])
n = end - start
start_idx = cu_num_tokens
end_idx = start_idx + query_len
input_ids[start_idx:end_idx] = token_ids[req_idx, start:end]
positions[start_idx:end_idx] = np.arange(start, end, dtype=np.int64)
input_ids[start_idx : start_idx + n] = prefill_token_ids[req_idx, start:end]
cu_num_tokens = end_idx
cu_num_tokens = start_idx + query_len
query_start_loc[i + 1] = cu_num_tokens
# Pad the inputs for CUDA graphs.
# Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention require it
query_start_loc[num_reqs + 1 :].fill(cu_num_tokens)
# Fill unused with 0 for full cuda graph mode.
seq_lens[num_reqs:].fill(0)
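The padding above exists for full CUDA graph mode: a captured graph may cover more request slots than are actually live, so query_start_loc is padded with the last cumulative count to stay non-decreasing (as FlashAttention expects) and unused seq_lens slots read as zero-length sequences. A small numeric example of the query_start_loc padding:

import numpy as np

max_num_reqs, num_reqs = 4, 2
query_start_loc = np.zeros(max_num_reqs + 1, dtype=np.int32)
query_start_loc[1] = 3  # request 0 has 3 scheduled tokens
query_start_loc[2] = 7  # request 1 has 4 scheduled tokens
query_start_loc[num_reqs + 1 :] = 7  # pad the unused tail
assert (np.diff(query_start_loc) >= 0).all()  # [0, 3, 7, 7, 7]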
def prepare_inputs(
def prepare_prefill_inputs(
idx_mapping: np.ndarray,
prefill_token_ids: np.ndarray,
num_computed_tokens: np.ndarray,
num_scheduled_tokens: np.ndarray,
total_num_tokens: int,
prefill_token_ids: np.ndarray,
num_computed_prefill_tokens: np.ndarray,
prefill_len: np.ndarray,
input_ids: CpuGpuBuffer,
positions: CpuGpuBuffer,
query_start_loc: CpuGpuBuffer,
seq_lens: CpuGpuBuffer,
num_tokens: int,
) -> None:
_prepare_inputs(
_prepare_prefill_inputs(
idx_mapping,
prefill_token_ids,
num_computed_tokens,
num_scheduled_tokens,
prefill_token_ids,
num_computed_prefill_tokens,
prefill_len,
input_ids.np,
positions.np,
query_start_loc.np,
seq_lens.np,
)
input_ids.copy_to_gpu(num_tokens)
positions.copy_to_gpu(num_tokens)
input_ids.copy_to_gpu(total_num_tokens)
# NOTE(woosuk): We should copy the whole query_start_loc and seq_lens
# tensors from CPU to GPU, because they may include paddings needed
# for full CUDA graph mode.
query_start_loc.copy_to_gpu()
seq_lens.copy_to_gpu()
return
@triton.jit
def _combine_last_token_ids_kernel(
def _prepare_pos_seq_lens_kernel(
pos_ptr,
seq_lens_ptr,
idx_mapping_ptr,
query_start_loc_ptr,
num_computed_tokens_ptr,
max_num_reqs,
BLOCK_SIZE: tl.constexpr,
):
req_id = tl.program_id(0)
num_reqs = tl.num_programs(0) - 1
if req_id == num_reqs:
# Pad unused seq_lens as 0 for full CUDA graphs.
for i in tl.range(num_reqs, max_num_reqs, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < max_num_reqs
tl.store(seq_lens_ptr + block, 0, mask=mask)
return
req_state_idx = tl.load(idx_mapping_ptr + req_id)
num_computed_tokens = tl.load(num_computed_tokens_ptr + req_state_idx)
start = tl.load(query_start_loc_ptr + req_id)
end = tl.load(query_start_loc_ptr + req_id + 1)
query_len = end - start
seq_len = num_computed_tokens + query_len
tl.store(seq_lens_ptr + req_id, seq_len)
for i in tl.range(0, query_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < query_len
pos = num_computed_tokens + block
tl.store(pos_ptr + start + block, pos, mask=mask)
def prepare_pos_seq_lens(
idx_mapping: torch.Tensor,
query_start_loc: torch.Tensor,
num_computed_tokens: torch.Tensor,
pos: torch.Tensor,
seq_lens: torch.Tensor,
) -> None:
num_reqs = idx_mapping.shape[0]
# NOTE(woosuk): We do +1 because the last thread block is used
# to pad unused seq_lens as 0 for full CUDA graphs.
_prepare_pos_seq_lens_kernel[(num_reqs + 1,)](
pos,
seq_lens,
idx_mapping,
query_start_loc,
num_computed_tokens,
seq_lens.shape[0],
BLOCK_SIZE=1024,
)
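For reference, the kernel launched above has the same effect as the following eager PyTorch (a sketch of the semantics, not the code path actually used): each request's positions are its computed-token offset plus 0..query_len, its seq_len is computed plus query_len, and the trailing seq_lens slots are zeroed for full CUDA graphs.

import torch

def prepare_pos_seq_lens_reference(
    idx_mapping: torch.Tensor,        # [num_reqs]
    query_start_loc: torch.Tensor,    # [num_reqs + 1]
    num_computed_tokens: torch.Tensor,
    pos: torch.Tensor,                # [num_tokens]
    seq_lens: torch.Tensor,           # [max_num_reqs]
) -> None:
    num_reqs = idx_mapping.shape[0]
    seq_lens[num_reqs:] = 0  # pad unused slots for full CUDA graphs
    for i in range(num_reqs):
        start = int(query_start_loc[i])
        end = int(query_start_loc[i + 1])
        computed = int(num_computed_tokens[idx_mapping[i]])
        seq_lens[i] = computed + (end - start)
        pos[start:end] = torch.arange(
            computed, computed + (end - start), device=pos.device
        )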
@triton.jit
def _combine_sampled_and_draft_tokens_kernel(
input_ids_ptr,
idx_mapping_ptr,
last_token_ids_ptr,
last_sampled_tokens_ptr,
query_start_loc_ptr,
seq_lens_ptr,
prefill_len_ptr,
@@ -239,26 +289,56 @@ def _combine_last_token_ids_kernel(
# Handling prefill tokens.
return
last_token_id = tl.load(last_token_ids_ptr + req_state_idx)
last_token_id = tl.load(last_sampled_tokens_ptr + req_state_idx)
end = tl.load(query_start_loc_ptr + batch_idx + 1)
tl.store(input_ids_ptr + end - 1, last_token_id)
def combine_last_token_ids(
def combine_sampled_and_draft_tokens(
input_ids: torch.Tensor,
idx_mapping: torch.Tensor,
last_token_ids: torch.Tensor,
last_sampled_tokens: torch.Tensor,
query_start_loc: torch.Tensor,
seq_lens: torch.Tensor,
prefill_len: torch.Tensor,
) -> torch.Tensor:
num_reqs = seq_lens.shape[0]
_combine_last_token_ids_kernel[(num_reqs,)](
_combine_sampled_and_draft_tokens_kernel[(num_reqs,)](
input_ids,
idx_mapping,
last_token_ids,
last_sampled_tokens,
query_start_loc,
seq_lens,
prefill_len,
)
return input_ids
@triton.jit
def _update_num_computed_tokens_kernel(
idx_mapping_ptr,
num_computed_tokens_ptr,
query_start_loc_ptr,
):
req_id = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + req_id)
start = tl.load(query_start_loc_ptr + req_id)
end = tl.load(query_start_loc_ptr + req_id + 1)
query_len = end - start
n = tl.load(num_computed_tokens_ptr + req_state_idx)
tl.store(num_computed_tokens_ptr + req_state_idx, n + query_len)
def update_num_computed_tokens(
idx_mapping: torch.Tensor,
num_computed_tokens: torch.Tensor,
query_start_loc: torch.Tensor,
) -> None:
num_reqs = idx_mapping.shape[0]
_update_num_computed_tokens_kernel[(num_reqs,)](
idx_mapping,
num_computed_tokens,
query_start_loc,
)
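update_num_computed_tokens advances each scheduled request's counter by its query length entirely on the GPU, so the host never has to know how many tokens were actually processed this step. An equivalent eager sketch (assuming each request appears at most once in idx_mapping):

import torch

def update_num_computed_tokens_reference(
    idx_mapping: torch.Tensor,          # [num_reqs], batch idx -> request state idx
    num_computed_tokens: torch.Tensor,  # [max_num_reqs], lives on the GPU
    query_start_loc: torch.Tensor,      # [num_reqs + 1]
) -> None:
    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    # index_add_ keeps the update on the device; each index occurs once here.
    num_computed_tokens.index_add_(0, idx_mapping, query_lens)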

View File

@@ -39,8 +39,10 @@ from vllm.v1.worker.gpu.dp_utils import get_batch_metadata_across_dp
from vllm.v1.worker.gpu.input_batch import (
InputBatch,
InputBuffers,
combine_last_token_ids,
prepare_inputs,
combine_sampled_and_draft_tokens,
prepare_pos_seq_lens,
prepare_prefill_inputs,
update_num_computed_tokens,
)
from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs
from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
@@ -179,6 +181,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.vllm_config,
self.device,
)
# TODO(woosuk): Support other backends.
if not all(b.get_name() == "FLASH_ATTN" for b in self.attn_backends.values()):
raise NotImplementedError("Only FLASH_ATTN backend is supported currently.")
self.kv_caches: list[torch.Tensor] = []
init_kv_cache(
@@ -196,8 +201,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
slot_mappings = self.block_tables.get_dummy_slot_mappings(
input_batch.num_tokens
)
num_computed_tokens_cpu = torch.zeros(
input_batch.num_reqs, dtype=torch.int32, device="cpu"
num_computed_tokens = torch.zeros(
input_batch.num_reqs, dtype=torch.int32, device=self.device
)
attn_metadata = build_attn_metadata(
attn_metadata_builders=self.attn_metadata_builders,
@@ -205,7 +210,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_tokens=input_batch.num_tokens,
query_start_loc=self.input_buffers.query_start_loc,
seq_lens=self.input_buffers.seq_lens,
num_computed_tokens_cpu=num_computed_tokens_cpu,
seq_lens_np=input_batch.seq_lens_np,
num_computed_tokens_cpu=num_computed_tokens,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
@@ -368,6 +374,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cu_num_new_blocks[i].append(x + len(block_ids))
new_block_ids[i].extend(block_ids)
overwrite.append(True)
# Update the GPU tensors for request states.
if scheduler_output.scheduled_new_reqs:
self.req_states.prefill_len.copy_to_gpu()
# Add new blocks for the existing requests.
cached_reqs = scheduler_output.scheduled_cached_reqs
@@ -421,46 +430,60 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables = self.block_tables.gather_block_tables(idx_mapping)
prepare_inputs(
# Copy prefill tokens from CPU to GPU and get query_start_loc.
prepare_prefill_inputs(
idx_mapping_np,
self.req_states.prefill_token_ids,
self.req_states.num_computed_tokens,
num_scheduled_tokens,
self.input_buffers.input_ids,
self.input_buffers.positions,
self.input_buffers.query_start_loc,
self.input_buffers.seq_lens,
num_tokens,
self.req_states.prefill_token_ids,
self.req_states.num_computed_prefill_tokens,
self.req_states.prefill_len.np,
self.input_buffers.input_ids,
self.input_buffers.query_start_loc,
)
query_start_loc = self.input_buffers.query_start_loc
query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
query_start_loc_np = query_start_loc.np[: num_reqs + 1]
seq_lens_gpu = self.input_buffers.seq_lens.gpu[:num_reqs]
seq_lens_np = self.input_buffers.seq_lens.np[:num_reqs]
# Some input token ids are directly read from the last sampled tokens.
combine_last_token_ids(
# Prepare positions and seq_lens.
prepare_pos_seq_lens(
idx_mapping,
query_start_loc_gpu,
self.req_states.num_computed_tokens,
self.input_buffers.positions,
self.input_buffers.seq_lens,
)
seq_lens = self.input_buffers.seq_lens[:num_reqs]
# Some input token ids are directly read from the last sampled tokens
# and draft tokens.
combine_sampled_and_draft_tokens(
self.input_buffers.input_ids.gpu,
idx_mapping,
self.req_states.last_sampled_tokens,
query_start_loc_gpu,
seq_lens_gpu,
self.req_states.prefill_len.copy_to_gpu(),
seq_lens,
self.req_states.prefill_len.gpu,
)
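combine_sampled_and_draft_tokens fills in the token ids that the CPU does not have: with async scheduling, the token sampled in the previous step only exists on the GPU in last_sampled_tokens, so decode requests read their newest input token from there, while requests still prefilling keep the prompt tokens copied from the CPU buffer. Roughly equivalent eager logic for the last-sampled-token part (a sketch with assumed 1-D shapes; draft-token handling is left out):

import torch

def combine_last_sampled_tokens_reference(
    input_ids: torch.Tensor,            # [num_tokens], already holds prefill tokens
    idx_mapping: torch.Tensor,          # [num_reqs]
    last_sampled_tokens: torch.Tensor,  # [max_num_reqs]
    query_start_loc: torch.Tensor,      # [num_reqs + 1]
    seq_lens: torch.Tensor,             # [num_reqs]
    prefill_len: torch.Tensor,          # [max_num_reqs]
) -> None:
    last_pos = query_start_loc[1:] - 1
    # Requests whose sequence extends past the prompt need the previously
    # sampled token at their final query position.
    is_decode = seq_lens > prefill_len[idx_mapping]
    input_ids[last_pos[is_decode]] = last_sampled_tokens[idx_mapping[is_decode]]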
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings = self.block_tables.compute_slot_mappings(
query_start_loc_gpu, self.input_buffers.positions.gpu[:num_tokens]
)
num_computed_tokens_cpu = torch.from_numpy(
self.req_states.num_computed_tokens[idx_mapping_np]
query_start_loc_gpu, self.input_buffers.positions[:num_tokens]
)
# Logits indices to sample next token from.
logits_indices = query_start_loc_gpu[1:] - 1
# Get num_computed_tokens.
# HACK(woosuk): Here, we use num_computed_tokens on GPU instead of
# num_computed_tokens_cpu. This works for most cases.
num_computed_tokens = self.req_states.num_computed_tokens[idx_mapping]
# HACK(woosuk): Only GPU has the exact seq_lens because at this point
# CPU does not know how many draft tokens are accepted/rejected in the
# previous step. Therefore, we use max_model_len to be safe.
# NOTE(woosuk): This only works for FA3 backend.
seq_lens_np = np.full(num_reqs, self.max_model_len, dtype=np.int32)
# Layer name -> attention metadata.
attn_metadata = build_attn_metadata(
attn_metadata_builders=self.attn_metadata_builders,
@@ -468,14 +491,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_tokens=num_tokens,
query_start_loc=self.input_buffers.query_start_loc,
seq_lens=self.input_buffers.seq_lens,
num_computed_tokens_cpu=num_computed_tokens_cpu,
seq_lens_np=seq_lens_np,
num_computed_tokens_cpu=num_computed_tokens,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
)
input_ids = self.input_buffers.input_ids.gpu[:num_tokens_after_padding]
positions = self.input_buffers.positions.gpu[:num_tokens_after_padding]
positions = self.input_buffers.positions[:num_tokens_after_padding]
return InputBatch(
req_ids=req_ids,
num_reqs=num_reqs,
@@ -486,7 +510,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_tokens_after_padding=num_tokens_after_padding,
query_start_loc=query_start_loc_gpu,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens_gpu,
seq_lens=seq_lens,
seq_lens_np=seq_lens_np,
input_ids=input_ids,
positions=positions,
@@ -500,11 +524,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_batch: InputBatch,
sampling_metadata: SamplingMetadata,
grammar_output: GrammarOutput | None,
) -> SamplerOutput:
) -> tuple[SamplerOutput, torch.Tensor]:
sample_hidden_states = hidden_states[input_batch.logits_indices]
logits = self.model.compute_logits(sample_hidden_states)
if grammar_output is not None:
# Apply grammar bitmask to the logits in-place.
# TODO(woosuk): Make compatible with spec decoding.
with async_barrier(self.structured_outputs_event):
apply_grammar_bitmask(
logits,
@@ -513,8 +538,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
grammar_output.grammar_bitmask,
self.input_buffers,
)
sampler_output = self.sampler(logits, sampling_metadata)
return sampler_output
# Get the number of sampled tokens.
# 0 if chunked-prefilling, 1 if not.
prefill_len = self.req_states.prefill_len.gpu[input_batch.idx_mapping]
is_chunked_prefilling = input_batch.seq_lens < prefill_len
num_sampled = (~is_chunked_prefilling).int()
return sampler_output, num_sampled
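num_sampled stays on the GPU and marks which requests actually produced a token this step: a request that is still chunk-prefilling (its sequence length has not yet reached its prompt length) samples nothing. A concrete example:

import torch

prefill_len = torch.tensor([8, 8, 8])  # prompt lengths of three requests
seq_lens = torch.tensor([4, 8, 9])     # after this step's scheduled tokens
is_chunked_prefilling = seq_lens < prefill_len
num_sampled = (~is_chunked_prefilling).int()
assert num_sampled.tolist() == [0, 1, 1]  # request 0 is still mid-prompt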
def compute_prompt_logprobs(
self,
@@ -527,11 +558,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# No request asks for prompt logprobs.
return {}
num_computed_tokens = self.req_states.num_computed_tokens[idx_mapping_np]
prompt_lens = self.req_states.prompt_len[idx_mapping_np]
# NOTE(woosuk): -1 because the last prompt token's hidden state is not
# needed for prompt logprobs.
includes_prompt = num_computed_tokens < prompt_lens - 1
computed_prefill = self.req_states.num_computed_prefill_tokens[idx_mapping_np]
includes_prompt = computed_prefill < prompt_lens - 1
# NOTE(woosuk): If the request was resumed after preemption, its prompt
# logprobs must have been computed before preemption. Skip.
resumed_after_prompt = (
@@ -550,8 +581,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
token_ids[n - 1] = 0
# Handle chunked prompts.
seq_lens = self.input_buffers.seq_lens.np[: input_batch.num_reqs]
is_prompt_chunked = seq_lens < prompt_lens
pos_after_step = computed_prefill + input_batch.num_scheduled_tokens
is_prompt_chunked = pos_after_step < prompt_lens
prefill_token_ids = self.req_states.prefill_token_ids
query_start_loc = self.input_buffers.query_start_loc.np
for i, req_id in enumerate(input_batch.req_ids):
@@ -561,7 +592,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
continue
# The prompt is chunked. Get the next prompt token.
req_idx = input_batch.idx_mapping_np[i]
next_prompt_token = int(prefill_token_ids[req_idx, seq_lens[i]])
next_prompt_token = int(prefill_token_ids[req_idx, pos_after_step[i]])
idx = int(query_start_loc[i + 1] - 1)
# Set the next prompt token.
# NOTE(woosuk): This triggers a GPU operation.
@@ -617,48 +648,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def postprocess(
self,
sampler_output: SamplerOutput,
prompt_logprobs_dict: dict[str, LogprobsTensors],
input_batch: InputBatch,
) -> AsyncOutput | ModelRunnerOutput:
# Store the last sampled token ids.
self.req_states.last_sampled_tokens[input_batch.idx_mapping] = (
sampler_output.sampled_token_ids
sampled_tokens: torch.Tensor,
num_sampled: torch.Tensor,
) -> None:
# Update the number of computed tokens.
update_num_computed_tokens(
input_batch.idx_mapping,
self.req_states.num_computed_tokens,
input_batch.query_start_loc,
)
# Get the number of sampled tokens.
# 0 if chunked-prefilling, 1 if not.
idx_mapping_np = input_batch.idx_mapping_np
is_chunked_prefilling = (
input_batch.seq_lens_np < self.req_states.num_tokens[idx_mapping_np]
)
num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)
# Increment the number of tokens.
self.req_states.num_tokens[idx_mapping_np] += num_sampled_tokens
# Increment the number of computed tokens.
self.req_states.num_computed_tokens[idx_mapping_np] += (
input_batch.num_scheduled_tokens
computed_prefill = self.req_states.num_computed_prefill_tokens
# TODO(woosuk): Simplify this.
computed_prefill[idx_mapping_np] = np.minimum(
computed_prefill[idx_mapping_np] + input_batch.num_scheduled_tokens,
self.req_states.prefill_len.np[idx_mapping_np],
)
model_runner_output = ModelRunnerOutput(
req_ids=input_batch.req_ids,
req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
sampled_token_ids=None, # type: ignore
logprobs=None,
prompt_logprobs_dict=prompt_logprobs_dict, # type: ignore
pooler_output=[],
kv_connector_output=None,
num_nans_in_logits=None,
)
async_output = AsyncOutput(
model_runner_output=model_runner_output,
sampler_output=sampler_output,
num_sampled_tokens=num_sampled_tokens,
copy_stream=self.output_copy_stream,
copy_event=self.output_copy_event,
)
if self.use_async_scheduling:
return async_output
return async_output.get_output()
# Store the last sampled token ids.
last_sampled = sampled_tokens
self.req_states.last_sampled_tokens[input_batch.idx_mapping] = last_sampled
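The host-side prefill counter is advanced by the scheduled token count but clamped to the prompt length, so decode steps (whose tokens lie past the prompt) leave it saturated at prefill_len. A worked example:

import numpy as np

prefill_len = np.array([8, 8], dtype=np.int32)
computed_prefill = np.array([6, 8], dtype=np.int32)     # req 0 mid-prompt, req 1 decoding
num_scheduled_tokens = np.array([2, 1], dtype=np.int32)
computed_prefill = np.minimum(
    computed_prefill + num_scheduled_tokens, prefill_len
)
assert computed_prefill.tolist() == [8, 8]  # req 1 stays clamped at its prompt length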
def get_cudagraph_and_dp_padding(
self,
@@ -782,6 +792,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
else:
# Run PyTorch model in eager mode.
# TODO(woosuk): Support piecewise CUDA graph.
with set_forward_context(
input_batch.attn_metadata,
self.vllm_config,
@@ -807,13 +818,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.execute_model_state = None # type: ignore
assert sampling_metadata is not None
sampler_output = self.sample(
sampler_output, num_sampled_tokens = self.sample(
hidden_states, input_batch, sampling_metadata, grammar_output
)
prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch)
output = self.postprocess(
sampler_output,
prompt_logprobs_dict,
input_batch,
# Prepare the model runner output.
model_runner_output = ModelRunnerOutput(
req_ids=input_batch.req_ids,
# NOTE(woosuk): req_id_to_index is unused in this model runner.
# Only for compatibility with the existing model runner and scheduler.
req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
sampled_token_ids=None, # type: ignore
logprobs=None,
prompt_logprobs_dict=prompt_logprobs_dict, # type: ignore
pooler_output=[],
kv_connector_output=None,
num_nans_in_logits=None,
)
return output
async_output = AsyncOutput(
model_runner_output=model_runner_output,
sampler_output=sampler_output,
num_sampled_tokens=num_sampled_tokens,
copy_stream=self.output_copy_stream,
copy_event=self.output_copy_event,
)
# Postprocess results and update request states.
# NOTE: This is intentionally done after creating the AsyncOutput,
# ensuring that `copy_event` is recorded before calling postprocess.
# This sequencing may slightly reduce latency, since the async D2H copy
# does not need to wait for postprocessing to finish.
self.postprocess(
input_batch, sampler_output.sampled_token_ids, num_sampled_tokens
)
if self.use_async_scheduling:
return async_output
return async_output.get_output()

View File

@@ -85,8 +85,12 @@ class RequestState:
dtype=np.int32,
)
self.prefill_len = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
self.num_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
self.num_computed_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
# Number of computed tokens.
self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
self.num_computed_tokens = torch.zeros(
self.max_num_reqs, dtype=torch.int32, device=device
)
# Last sampled tokens.
self.last_sampled_tokens = torch.zeros(
@@ -145,7 +149,10 @@ class RequestState:
)
self.prefill_len.np[req_idx] = prefill_len
self.prefill_token_ids[req_idx, :prefill_len] = prefill_token_ids
self.num_tokens[req_idx] = prefill_len
self.num_computed_prefill_tokens[req_idx] = num_computed_tokens
# FIXME(woosuk): This triggers a GPU operation whenever a new request is added.
# Optimize this.
self.num_computed_tokens[req_idx] = num_computed_tokens
if lora_request is not None:
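The FIXME above refers to the cost of writing a Python scalar into a single element of a CUDA tensor: each assignment issues a small host-to-device transfer, once per newly added request. A tiny illustration of the operation in question:

import torch

num_computed_tokens = torch.zeros(4, dtype=torch.int32, device="cuda")
num_computed_tokens[2] = 17  # per-request scalar write: one small H2D copy each time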