[Model Runner V2] Support spec decoding [1/N] (#29274)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon 2025-11-23 10:09:06 -08:00 committed by GitHub
parent 7f12c82fa6
commit b004c00418
5 changed files with 347 additions and 26 deletions


@@ -35,6 +35,7 @@ class InputBuffers:
self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
self.cu_num_logits = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
# Structured outputs.
self.bitmask_indices = self._make_buffer(max_num_reqs, dtype=torch.int32)
@@ -64,6 +65,7 @@ class InputBatch:
# sum(num_scheduled_tokens)
num_tokens: int
num_tokens_after_padding: int
num_draft_tokens: int
# [num_reqs + 1]
query_start_loc: torch.Tensor
@@ -80,8 +82,10 @@ class InputBatch:
# layer_name -> Metadata
attn_metadata: dict[str, Any]
# [num_reqs]
# [total_num_logits]
logits_indices: torch.Tensor
# [num_reqs + 1]
cu_num_logits: torch.Tensor
@classmethod
def make_dummy(
@@ -118,6 +122,7 @@ class InputBatch:
positions = input_buffers.positions[:num_tokens]
# attn_metadata = defaultdict(lambda: None)
logits_indices = query_start_loc[1:] - 1
cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32)
return cls(
req_ids=req_ids,
num_reqs=num_reqs,
@@ -126,6 +131,7 @@ class InputBatch:
num_scheduled_tokens=num_scheduled_tokens,
num_tokens=num_tokens,
num_tokens_after_padding=num_tokens,
num_draft_tokens=0,
query_start_loc=query_start_loc,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens,
@@ -134,6 +140,7 @@ class InputBatch:
positions=positions,
attn_metadata=None, # type: ignore
logits_indices=logits_indices,
cu_num_logits=cu_num_logits,
)
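For orientation: in the dummy batch there are no draft tokens, so every request contributes exactly one logit. That is why cu_num_logits degenerates to an arange and logits_indices is simply the last position of each query. A tiny illustrative sketch (values made up, not from the diff):

# Illustrative only: one logit per request when no draft tokens are scheduled.
import torch

query_start_loc = torch.tensor([0, 3, 7, 9], dtype=torch.int32)  # 3 requests
num_reqs = query_start_loc.numel() - 1
cu_num_logits = torch.arange(num_reqs + 1, dtype=torch.int32)    # [0, 1, 2, 3]
logits_indices = query_start_loc[1:].to(torch.int64) - 1         # [2, 6, 8]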
@@ -279,19 +286,53 @@ def _combine_sampled_and_draft_tokens_kernel(
query_start_loc_ptr,
seq_lens_ptr,
prefill_len_ptr,
draft_tokens_ptr,
draft_tokens_stride,
cu_num_logits_ptr,
logits_indices_ptr,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
# Get the number of logits and draft tokens.
cu_num_logits_start = tl.load(cu_num_logits_ptr + batch_idx)
cu_num_logits_end = tl.load(cu_num_logits_ptr + batch_idx + 1)
num_logits = cu_num_logits_end - cu_num_logits_start
num_draft_tokens = num_logits - 1
# Compute the logits indices.
block = tl.arange(0, BLOCK_SIZE)
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
logits_start = query_end - num_logits
tl.store(
logits_indices_ptr + cu_num_logits_start + block,
logits_start + block,
mask=block < num_logits,
)
seq_len = tl.load(seq_lens_ptr + batch_idx)
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
if seq_len <= prefill_len:
# Handling prefill tokens.
# Handling prefill tokens. No sampled or draft tokens.
return
# Write the last sampled token ID to input_ids.
last_token_id = tl.load(last_sampled_tokens_ptr + req_state_idx)
end = tl.load(query_start_loc_ptr + batch_idx + 1)
tl.store(input_ids_ptr + end - 1, last_token_id)
tl.store(input_ids_ptr + query_end - num_logits, last_token_id)
# Write the draft tokens (if any) to input_ids.
if num_draft_tokens > 0:
mask = block < num_draft_tokens
draft_tokens = tl.load(
draft_tokens_ptr + req_state_idx * draft_tokens_stride + block,
mask=mask,
)
tl.store(
input_ids_ptr + query_end - num_draft_tokens + block,
draft_tokens,
mask=mask,
)
def combine_sampled_and_draft_tokens(
@@ -301,8 +342,18 @@ def combine_sampled_and_draft_tokens(
query_start_loc: torch.Tensor,
seq_lens: torch.Tensor,
prefill_len: torch.Tensor,
draft_tokens: torch.Tensor,
cu_num_logits: torch.Tensor,
num_logits: int,
) -> torch.Tensor:
num_reqs = seq_lens.shape[0]
num_speculative_steps = draft_tokens.shape[-1]
logits_indices = torch.empty(
num_logits,
dtype=torch.int64,
device=input_ids.device,
)
_combine_sampled_and_draft_tokens_kernel[(num_reqs,)](
input_ids,
idx_mapping,
@@ -310,35 +361,80 @@ def combine_sampled_and_draft_tokens(
query_start_loc,
seq_lens,
prefill_len,
draft_tokens,
draft_tokens.stride(0),
cu_num_logits,
logits_indices,
# NOTE(woosuk): Add 1 to ensure the block can cover the last sampled token
# in addition to all draft tokens.
BLOCK_SIZE=triton.next_power_of_2(num_speculative_steps + 1),
)
return input_ids
return logits_indices
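For readers new to this kernel, the per-request logic can be written as a plain PyTorch loop. This is a hedged reference sketch, not the implementation: it assumes last_sampled_tokens and prefill_len are flat [max_num_reqs] tensors and draft_tokens is [max_num_reqs, num_speculative_steps], as suggested by the pointer arithmetic above.

# Reference sketch of the Triton kernel above (illustrative, per-request loop).
import torch

def combine_reference(input_ids, idx_mapping, last_sampled_tokens,
                      query_start_loc, seq_lens, prefill_len,
                      draft_tokens, cu_num_logits, num_logits):
    logits_indices = torch.empty(num_logits, dtype=torch.int64,
                                 device=input_ids.device)
    for i in range(seq_lens.shape[0]):
        state = int(idx_mapping[i])
        n_logits = int(cu_num_logits[i + 1] - cu_num_logits[i])
        n_draft = n_logits - 1
        q_end = int(query_start_loc[i + 1])
        out_start = int(cu_num_logits[i])
        # Logits are taken from the last n_logits positions of the query.
        logits_indices[out_start:out_start + n_logits] = torch.arange(
            q_end - n_logits, q_end, device=input_ids.device)
        if int(seq_lens[i]) <= int(prefill_len[state]):
            continue  # still chunked-prefilling: nothing to write into input_ids
        # The last sampled token goes right before the draft tokens.
        input_ids[q_end - n_logits] = last_sampled_tokens[state]
        if n_draft > 0:
            input_ids[q_end - n_draft:q_end] = draft_tokens[state, :n_draft]
    return logits_indices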
@triton.jit
def _update_num_computed_tokens_kernel(
def _post_update_kernel(
idx_mapping_ptr,
num_computed_tokens_ptr,
last_sampled_tokens_ptr,
sampled_tokens_ptr,
sampled_tokens_stride,
num_sampled_ptr,
query_start_loc_ptr,
cu_num_logits_ptr,
):
req_id = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + req_id)
start = tl.load(query_start_loc_ptr + req_id)
end = tl.load(query_start_loc_ptr + req_id + 1)
query_len = end - start
num_sampled = tl.load(num_sampled_ptr + req_id)
if num_sampled > 0:
token_id = tl.load(
sampled_tokens_ptr + req_id * sampled_tokens_stride + num_sampled - 1
)
tl.store(last_sampled_tokens_ptr + req_state_idx, token_id)
n = tl.load(num_computed_tokens_ptr + req_state_idx)
tl.store(num_computed_tokens_ptr + req_state_idx, n + query_len)
query_start = tl.load(query_start_loc_ptr + req_id)
query_end = tl.load(query_start_loc_ptr + req_id + 1)
query_len = query_end - query_start
num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
num_computed += query_len
# Consider the rejected tokens in spec decoding.
if num_sampled > 0:
# NOTE(woosuk): We must skip num_sampled == 0 to account for chunked prefills.
logits_start = tl.load(cu_num_logits_ptr + req_id)
logits_end = tl.load(cu_num_logits_ptr + req_id + 1)
num_logits = logits_end - logits_start
num_rejected = num_logits - num_sampled
num_computed -= num_rejected
tl.store(num_computed_tokens_ptr + req_state_idx, num_computed)
def update_num_computed_tokens(
def post_update(
# [num_reqs]
idx_mapping: torch.Tensor,
# [max_num_reqs]
num_computed_tokens: torch.Tensor,
# [max_num_reqs]
last_sampled_tokens: torch.Tensor,
# [num_reqs, num_speculative_steps + 1]
sampled_tokens: torch.Tensor,
# [num_reqs]
num_sampled: torch.Tensor,
# [num_reqs + 1]
query_start_loc: torch.Tensor,
# [num_reqs + 1]
cu_num_logits: torch.Tensor,
) -> None:
num_reqs = idx_mapping.shape[0]
_update_num_computed_tokens_kernel[(num_reqs,)](
_post_update_kernel[(num_reqs,)](
idx_mapping,
num_computed_tokens,
last_sampled_tokens,
sampled_tokens,
sampled_tokens.stride(0),
num_sampled,
query_start_loc,
cu_num_logits,
num_warps=1,
)
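In words, the post-update kernel caches the last accepted token for the next step and advances num_computed_tokens by the query length minus any rejected draft tokens, skipping the rollback when num_sampled == 0 (chunked prefill). A hedged per-request sketch of the same bookkeeping, assuming flat [max_num_reqs] state tensors:

# Reference sketch of _post_update_kernel (illustrative, per-request loop).
def post_update_reference(idx_mapping, num_computed_tokens, last_sampled_tokens,
                          sampled_tokens, num_sampled, query_start_loc,
                          cu_num_logits):
    for i in range(idx_mapping.shape[0]):
        state = int(idx_mapping[i])
        n_sampled = int(num_sampled[i])
        if n_sampled > 0:
            # Cache the last accepted token for the next scheduling step.
            last_sampled_tokens[state] = sampled_tokens[i, n_sampled - 1]
        query_len = int(query_start_loc[i + 1] - query_start_loc[i])
        new_computed = int(num_computed_tokens[state]) + query_len
        if n_sampled > 0:
            # Roll back draft tokens rejected during verification.
            n_logits = int(cu_num_logits[i + 1] - cu_num_logits[i])
            new_computed -= n_logits - n_sampled
        num_computed_tokens[state] = new_computed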


@@ -40,11 +40,12 @@ from vllm.v1.worker.gpu.input_batch import (
InputBatch,
InputBuffers,
combine_sampled_and_draft_tokens,
post_update,
prepare_pos_seq_lens,
prepare_prefill_inputs,
update_num_computed_tokens,
)
from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs
from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
@@ -100,10 +101,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.input_prep_event = None
self.structured_outputs_event = None
if self.speculative_config is not None:
self.do_spec_decode = True
self.num_speculative_steps = self.speculative_config.num_speculative_tokens
else:
self.do_spec_decode = False
self.num_speculative_steps = 0
self.req_states = RequestState(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_batched_tokens=self.max_num_tokens,
num_speculative_steps=self.num_speculative_steps,
vocab_size=self.vocab_size,
device=self.device,
pin_memory=self.pin_memory,
@@ -427,6 +436,32 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
idx_mapping_np = idx_mapping.np[:num_reqs]
idx_mapping = idx_mapping.copy_to_gpu(num_reqs)
# Get the number of draft tokens for each request.
if not scheduler_output.scheduled_spec_decode_tokens:
# No draft token scheduled (common case).
total_num_draft_tokens = 0
total_num_logits = num_reqs
cu_num_logits = torch.arange(
num_reqs + 1, device=self.device, dtype=torch.int32
)
else:
draft_tokens = scheduler_output.scheduled_spec_decode_tokens
num_draft_tokens = np.array(
[
len(draft_tokens[req_id]) if req_id in draft_tokens else 0
for req_id in req_ids
],
dtype=np.int32,
)
total_num_draft_tokens = int(num_draft_tokens.sum())
total_num_logits = num_reqs + total_num_draft_tokens
np.cumsum(
num_draft_tokens + 1,
out=self.input_buffers.cu_num_logits.np[1 : num_reqs + 1],
)
cu_num_logits = self.input_buffers.cu_num_logits.copy_to_gpu(num_reqs + 1)
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables = self.block_tables.gather_block_tables(idx_mapping)
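As a concrete example of the cu_num_logits construction above (numbers made up): each request produces num_draft_tokens + 1 logits, so for draft counts [2, 0, 3] the prefix sum gives [0, 3, 4, 8] and total_num_logits is 8.

# Illustrative numbers only; mirrors the cumsum into the cu_num_logits buffer.
import numpy as np

num_draft_tokens = np.array([2, 0, 3], dtype=np.int32)
cu_num_logits = np.zeros(num_draft_tokens.size + 1, dtype=np.int32)
cu_num_logits[1:] = np.cumsum(num_draft_tokens + 1)
# cu_num_logits == [0, 3, 4, 8]; total_num_logits == num_reqs + total drafts == 8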
@@ -456,14 +491,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
seq_lens = self.input_buffers.seq_lens[:num_reqs]
# Some input token ids are directly read from the last sampled tokens
# and draft tokens.
combine_sampled_and_draft_tokens(
# and draft tokens. Also, get the logits indices to sample tokens from.
logits_indices = combine_sampled_and_draft_tokens(
self.input_buffers.input_ids.gpu,
idx_mapping,
self.req_states.last_sampled_tokens,
query_start_loc_gpu,
seq_lens,
self.req_states.prefill_len.gpu,
self.req_states.draft_tokens,
cu_num_logits,
total_num_logits,
)
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
@@ -471,9 +509,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
query_start_loc_gpu, self.input_buffers.positions[:num_tokens]
)
# Logits indices to sample next token from.
logits_indices = query_start_loc_gpu[1:] - 1
# Get num_computed_tokens.
# HACK(woosuk): Here, we use num_computed_tokens on GPU instead of
# num_computed_tokens_cpu. This works for most cases.
@@ -508,6 +543,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_scheduled_tokens=num_scheduled_tokens,
num_tokens=num_tokens,
num_tokens_after_padding=num_tokens_after_padding,
num_draft_tokens=total_num_draft_tokens,
query_start_loc=query_start_loc_gpu,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens,
@@ -516,6 +552,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
positions=positions,
attn_metadata=attn_metadata,
logits_indices=logits_indices,
cu_num_logits=cu_num_logits,
)
def sample(
@@ -530,6 +567,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if grammar_output is not None:
# Apply grammar bitmask to the logits in-place.
# TODO(woosuk): Make compatible with spec decoding.
assert input_batch.num_draft_tokens == 0
with async_barrier(self.structured_outputs_event):
apply_grammar_bitmask(
logits,
@@ -539,12 +577,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.input_buffers,
)
# Sample tokens and compute logprobs (if needed).
sampler_output = self.sampler(logits, sampling_metadata)
# Get the number of sampled tokens.
# 0 if chunked-prefilling, 1 if not.
prefill_len = self.req_states.prefill_len.gpu[input_batch.idx_mapping]
is_chunked_prefilling = input_batch.seq_lens < prefill_len
num_sampled = (~is_chunked_prefilling).int()
if input_batch.num_draft_tokens == 0:
# No draft tokens (common case).
# 0 if chunked-prefilling, 1 if not.
num_sampled = (~is_chunked_prefilling).int()
else:
# Draft tokens for spec decoding.
input_ids = input_batch.input_ids[input_batch.logits_indices]
sampled_tokens, num_sampled = rejection_sample(
sampler_output.sampled_token_ids,
input_ids,
input_batch.cu_num_logits,
self.num_speculative_steps,
)
num_sampled *= ~is_chunked_prefilling
sampler_output.sampled_token_ids = sampled_tokens
# TODO(woosuk): Support logprobs with spec decoding.
return sampler_output, num_sampled
def compute_prompt_logprobs(
@@ -653,11 +707,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_sampled: torch.Tensor,
) -> None:
# Update the number of computed tokens.
update_num_computed_tokens(
post_update(
input_batch.idx_mapping,
self.req_states.num_computed_tokens,
self.req_states.last_sampled_tokens,
sampled_tokens,
num_sampled,
input_batch.query_start_loc,
input_batch.cu_num_logits,
)
# Update the number of computed prefill tokens.
idx_mapping_np = input_batch.idx_mapping_np
computed_prefill = self.req_states.num_computed_prefill_tokens
# TODO(woosuk): Simplify this.
@@ -666,10 +726,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.req_states.prefill_len.np[idx_mapping_np],
)
# Store the last sampled token ids.
last_sampled = sampled_tokens
self.req_states.last_sampled_tokens[input_batch.idx_mapping] = last_sampled
def get_cudagraph_and_dp_padding(
self,
scheduler_output: SchedulerOutput,
@@ -761,6 +817,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
sampling_metadata = self.req_states.make_sampling_metadata(
input_batch.idx_mapping_np, pos
)
if input_batch.num_draft_tokens > 0:
sampling_metadata = self.req_states.expand_sampling_metadata(
sampling_metadata, input_batch.cu_num_logits
)
if self.lora_config:
# Activate LoRA adapters.


@@ -0,0 +1,71 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
@triton.jit
def _rejection_sample_kernel(
sampled_ptr, # [num_reqs, num_speculative_steps + 1]
sampled_stride,
num_sampled_ptr, # [num_reqs]
target_sampled_ptr, # [num_draft_tokens + num_reqs]
input_ids_ptr, # [num_draft_tokens + num_reqs]
cu_num_logits_ptr, # [num_reqs + 1]
):
req_idx = tl.program_id(0)
start_idx = tl.load(cu_num_logits_ptr + req_idx)
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
num_tokens = end_idx - start_idx
num_sampled = 0
rejected = False
for i in range(num_tokens - 1):
if not rejected:
target_sampled = tl.load(target_sampled_ptr + start_idx + i)
draft_sampled = tl.load(input_ids_ptr + start_idx + i + 1)
tl.store(sampled_ptr + req_idx * sampled_stride + i, target_sampled)
num_sampled += 1
if target_sampled != draft_sampled:
rejected = True
if not rejected:
target_sampled = tl.load(target_sampled_ptr + start_idx + num_tokens - 1)
tl.store(
sampled_ptr + req_idx * sampled_stride + num_tokens - 1, target_sampled
)
num_sampled += 1
tl.store(num_sampled_ptr + req_idx, num_sampled)
def rejection_sample(
# [num_draft_tokens + num_reqs]
target_sampled: torch.Tensor,
# [num_draft_tokens + num_reqs]
input_ids: torch.Tensor,
# [num_reqs + 1]
cu_num_logits: torch.Tensor,
num_speculative_steps: int,
) -> tuple[torch.Tensor, torch.Tensor]:
num_reqs = cu_num_logits.shape[0] - 1
sampled = torch.empty(
num_reqs,
num_speculative_steps + 1,
dtype=target_sampled.dtype,
device=target_sampled.device,
)
num_sampled = torch.empty(
num_reqs,
dtype=torch.int32,
device=target_sampled.device,
)
_rejection_sample_kernel[(num_reqs,)](
sampled,
sampled.stride(0),
num_sampled,
target_sampled,
input_ids,
cu_num_logits,
num_warps=1,
)
return sampled, num_sampled
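Conceptually, the kernel above performs greedy draft verification: draft token i is accepted only while it matches the target model's sample at position i, acceptance stops at the first mismatch (the target's token replaces the rejected draft), and if every draft survives the bonus token from the last position is appended. A plain-Python sketch for a single request, using the same flat layout (illustrative only):

# Reference sketch of _rejection_sample_kernel for one request.
# target_sampled / input_ids are this request's slices of the flat
# [num_draft_tokens + num_reqs] tensors; the draft tokens sit at input_ids[1:].
def rejection_sample_one(target_sampled, input_ids):
    accepted = []
    num_tokens = len(target_sampled)
    for i in range(num_tokens - 1):
        accepted.append(int(target_sampled[i]))
        if int(target_sampled[i]) != int(input_ids[i + 1]):
            return accepted  # first mismatch: keep the correction, then stop
    # All drafts accepted: also keep the bonus token from the last position.
    accepted.append(int(target_sampled[num_tokens - 1]))
    return accepted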


@@ -7,6 +7,7 @@ import torch
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.utils import CpuGpuBuffer
@@ -63,6 +64,7 @@ class RequestState:
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
num_speculative_steps: int,
vocab_size: int,
device: torch.device,
pin_memory: bool,
@@ -70,6 +72,7 @@ class RequestState:
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_batched_tokens = max_num_batched_tokens
self.num_speculative_steps = num_speculative_steps
self.vocab_size = vocab_size
self.device = device
self.pin_memory = pin_memory
@@ -100,6 +103,14 @@ class RequestState:
device=device,
)
# Draft tokens.
self.draft_tokens = torch.zeros(
self.max_num_reqs,
self.num_speculative_steps,
dtype=torch.int64,
device=device,
)
# LoRA.
self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32)
self.lora_ids.fill(NO_LORA_ID)
@@ -226,6 +237,17 @@ class RequestState:
max_num_logprobs=max_num_logprobs,
)
def expand_sampling_metadata(
self,
sampling_metadata: SamplingMetadata,
cu_num_logits: torch.Tensor,
) -> SamplingMetadata:
# For draft tokens, we need to expand the sampling param tensors as
# each request samples multiple tokens in each step.
return expand_sampling_metadata(
sampling_metadata, cu_num_logits, self.num_speculative_steps
)
def make_lora_inputs(
self,
req_ids: list[str],
@@ -270,3 +292,75 @@ class Param:
class ExtraData:
lora_request: LoRARequest | None
in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list)
# NOTE(woosuk): Re-compilation can happen at runtime since top_p and top_k can be None.
@triton.jit
def _expand_sampling_metadata_kernel(
temp_ptr,
expanded_temp_ptr,
top_p_ptr,
expanded_top_p_ptr,
top_k_ptr,
expanded_top_k_ptr,
seeds_ptr,
expanded_seeds_ptr,
cu_num_logits_ptr,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
start_idx = tl.load(cu_num_logits_ptr + req_idx)
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
num_tokens = end_idx - start_idx
block = tl.arange(0, BLOCK_SIZE)
mask = block < num_tokens
temp = tl.load(temp_ptr + req_idx)
tl.store(expanded_temp_ptr + start_idx + block, temp, mask=mask)
if top_p_ptr is not None:
top_p = tl.load(top_p_ptr + req_idx)
tl.store(expanded_top_p_ptr + start_idx + block, top_p, mask=mask)
if top_k_ptr is not None:
top_k = tl.load(top_k_ptr + req_idx)
tl.store(expanded_top_k_ptr + start_idx + block, top_k, mask=mask)
seed = tl.load(seeds_ptr + req_idx)
tl.store(expanded_seeds_ptr + start_idx + block, seed, mask=mask)
def expand_sampling_metadata(
sampling_metadata: SamplingMetadata,
cu_num_logits: torch.Tensor,
num_speculative_steps: int,
) -> SamplingMetadata:
total_num_logits = sampling_metadata.pos.shape[0]
create_empty = lambda x: x.new_empty(total_num_logits) if x is not None else None
expanded_temp = create_empty(sampling_metadata.temperature)
expanded_top_p = create_empty(sampling_metadata.top_p)
expanded_top_k = create_empty(sampling_metadata.top_k)
expanded_seeds = create_empty(sampling_metadata.seeds)
num_reqs = cu_num_logits.shape[0] - 1
_expand_sampling_metadata_kernel[(num_reqs,)](
sampling_metadata.temperature,
expanded_temp,
sampling_metadata.top_p,
expanded_top_p,
sampling_metadata.top_k,
expanded_top_k,
sampling_metadata.seeds,
expanded_seeds,
cu_num_logits,
BLOCK_SIZE=triton.next_power_of_2(num_speculative_steps + 1),
)
return SamplingMetadata(
temperature=expanded_temp,
top_p=expanded_top_p,
top_k=expanded_top_k,
seeds=expanded_seeds,
pos=sampling_metadata.pos,
max_num_logprobs=sampling_metadata.max_num_logprobs,
)
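The expansion above is equivalent to a per-request repeat_interleave of the sampling-param tensors, with repeat counts taken from the cu_num_logits deltas; the Triton kernel just fuses the four tensors into one launch and tolerates optional (None) top_p/top_k. A rough equivalent for a single tensor, for illustration:

# Illustrative equivalent of the expansion using torch.repeat_interleave.
import torch

def expand_one(values, cu_num_logits):
    if values is None:
        return None
    repeats = (cu_num_logits[1:] - cu_num_logits[:-1]).to(torch.int64)
    return torch.repeat_interleave(values, repeats)

# e.g. per-request temperatures [0.7, 1.0] with cu_num_logits [0, 3, 4]
# expand to [0.7, 0.7, 0.7, 1.0], one entry per logit position.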