From b004c00418268daa61b3526358b661165a360f7d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 23 Nov 2025 10:09:06 -0800 Subject: [PATCH] [Model Runner V2] Support spec decoding [1/N] (#29274) Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu/input_batch.py | 122 ++++++++++++++++-- vllm/v1/worker/gpu/model_runner.py | 86 ++++++++++-- vllm/v1/worker/gpu/spec_decode/__init__.py | 0 .../gpu/spec_decode/rejection_sample.py | 71 ++++++++++ vllm/v1/worker/gpu/states.py | 94 ++++++++++++++ 5 files changed, 347 insertions(+), 26 deletions(-) create mode 100644 vllm/v1/worker/gpu/spec_decode/__init__.py create mode 100644 vllm/v1/worker/gpu/spec_decode/rejection_sample.py diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py index b671c093113ba..7675cb45170b5 100644 --- a/vllm/v1/worker/gpu/input_batch.py +++ b/vllm/v1/worker/gpu/input_batch.py @@ -35,6 +35,7 @@ class InputBuffers: self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device) self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32) self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device) + self.cu_num_logits = self._make_buffer(max_num_reqs + 1, dtype=torch.int32) # Structured outputs. self.bitmask_indices = self._make_buffer(max_num_reqs, dtype=torch.int32) @@ -64,6 +65,7 @@ class InputBatch: # sum(num_scheduled_tokens) num_tokens: int num_tokens_after_padding: int + num_draft_tokens: int # [num_reqs + 1] query_start_loc: torch.Tensor @@ -80,8 +82,10 @@ class InputBatch: # layer_name -> Metadata attn_metadata: dict[str, Any] - # [num_reqs] + # [total_num_logits] logits_indices: torch.Tensor + # [num_reqs + 1] + cu_num_logits: torch.Tensor @classmethod def make_dummy( @@ -118,6 +122,7 @@ class InputBatch: positions = input_buffers.positions[:num_tokens] # attn_metadata = defaultdict(lambda: None) logits_indices = query_start_loc[1:] - 1 + cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32) return cls( req_ids=req_ids, num_reqs=num_reqs, @@ -126,6 +131,7 @@ class InputBatch: num_scheduled_tokens=num_scheduled_tokens, num_tokens=num_tokens, num_tokens_after_padding=num_tokens, + num_draft_tokens=0, query_start_loc=query_start_loc, query_start_loc_np=query_start_loc_np, seq_lens=seq_lens, @@ -134,6 +140,7 @@ class InputBatch: positions=positions, attn_metadata=None, # type: ignore logits_indices=logits_indices, + cu_num_logits=cu_num_logits, ) @@ -279,19 +286,53 @@ def _combine_sampled_and_draft_tokens_kernel( query_start_loc_ptr, seq_lens_ptr, prefill_len_ptr, + draft_tokens_ptr, + draft_tokens_stride, + cu_num_logits_ptr, + logits_indices_ptr, + BLOCK_SIZE: tl.constexpr, ): batch_idx = tl.program_id(0) req_state_idx = tl.load(idx_mapping_ptr + batch_idx) + # Get the number of logits and draft tokens. + cu_num_logits_start = tl.load(cu_num_logits_ptr + batch_idx) + cu_num_logits_end = tl.load(cu_num_logits_ptr + batch_idx + 1) + num_logits = cu_num_logits_end - cu_num_logits_start + num_draft_tokens = num_logits - 1 + + # Compute the logits indices. + block = tl.arange(0, BLOCK_SIZE) + query_end = tl.load(query_start_loc_ptr + batch_idx + 1) + logits_start = query_end - num_logits + tl.store( + logits_indices_ptr + cu_num_logits_start + block, + logits_start + block, + mask=block < num_logits, + ) + seq_len = tl.load(seq_lens_ptr + batch_idx) prefill_len = tl.load(prefill_len_ptr + req_state_idx) if seq_len <= prefill_len: - # Handling prefill tokens. + # Handling prefill tokens. No sampled or draft tokens. 
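The logits-index computation in the kernel above derives each request's logit count from `cu_num_logits` and points `logits_indices` at the last `num_draft_tokens + 1` positions of that request's query span. A minimal CPU sketch of that index math, with a made-up 3-request batch (plain PyTorch loop for readability, not the Triton kernel):

import torch

# Toy batch: query lengths 4, 1, 5 and 2, 0, 3 scheduled draft tokens per request.
query_start_loc = torch.tensor([0, 4, 5, 10])
cu_num_logits = torch.tensor([0, 3, 4, 8])            # exclusive cumsum of (num_drafts + 1)

logits_indices = torch.empty(int(cu_num_logits[-1]), dtype=torch.int64)
for req in range(query_start_loc.numel() - 1):
    lo, hi = int(cu_num_logits[req]), int(cu_num_logits[req + 1])
    num_logits = hi - lo                               # num_draft_tokens + 1
    query_end = int(query_start_loc[req + 1])
    # Sample from the last num_logits positions of the request's query span.
    logits_indices[lo:hi] = torch.arange(query_end - num_logits, query_end)

print(logits_indices)  # tensor([1, 2, 3, 4, 6, 7, 8, 9])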
return + # Write the last sampled token ID to input_ids. last_token_id = tl.load(last_sampled_tokens_ptr + req_state_idx) - end = tl.load(query_start_loc_ptr + batch_idx + 1) - tl.store(input_ids_ptr + end - 1, last_token_id) + tl.store(input_ids_ptr + query_end - num_logits, last_token_id) + + # Write the draft tokens (if any) to input_ids. + if num_draft_tokens > 0: + mask = block < num_draft_tokens + draft_tokens = tl.load( + draft_tokens_ptr + req_state_idx * draft_tokens_stride + block, + mask=mask, + ) + tl.store( + input_ids_ptr + query_end - num_draft_tokens + block, + draft_tokens, + mask=mask, + ) def combine_sampled_and_draft_tokens( @@ -301,8 +342,18 @@ def combine_sampled_and_draft_tokens( query_start_loc: torch.Tensor, seq_lens: torch.Tensor, prefill_len: torch.Tensor, + draft_tokens: torch.Tensor, + cu_num_logits: torch.Tensor, + num_logits: int, ) -> torch.Tensor: num_reqs = seq_lens.shape[0] + num_speculative_steps = draft_tokens.shape[-1] + + logits_indices = torch.empty( + num_logits, + dtype=torch.int64, + device=input_ids.device, + ) _combine_sampled_and_draft_tokens_kernel[(num_reqs,)]( input_ids, idx_mapping, @@ -310,35 +361,80 @@ def combine_sampled_and_draft_tokens( query_start_loc, seq_lens, prefill_len, + draft_tokens, + draft_tokens.stride(0), + cu_num_logits, + logits_indices, + # NOTE(woosuk): Add 1 to ensure the block can cover the last sampled token + # in addition to all draft tokens. + BLOCK_SIZE=triton.next_power_of_2(num_speculative_steps + 1), ) - return input_ids + return logits_indices @triton.jit -def _update_num_computed_tokens_kernel( +def _post_update_kernel( idx_mapping_ptr, num_computed_tokens_ptr, + last_sampled_tokens_ptr, + sampled_tokens_ptr, + sampled_tokens_stride, + num_sampled_ptr, query_start_loc_ptr, + cu_num_logits_ptr, ): req_id = tl.program_id(0) req_state_idx = tl.load(idx_mapping_ptr + req_id) - start = tl.load(query_start_loc_ptr + req_id) - end = tl.load(query_start_loc_ptr + req_id + 1) - query_len = end - start + num_sampled = tl.load(num_sampled_ptr + req_id) + if num_sampled > 0: + token_id = tl.load( + sampled_tokens_ptr + req_id * sampled_tokens_stride + num_sampled - 1 + ) + tl.store(last_sampled_tokens_ptr + req_state_idx, token_id) - n = tl.load(num_computed_tokens_ptr + req_state_idx) - tl.store(num_computed_tokens_ptr + req_state_idx, n + query_len) + query_start = tl.load(query_start_loc_ptr + req_id) + query_end = tl.load(query_start_loc_ptr + req_id + 1) + query_len = query_end - query_start + + num_computed = tl.load(num_computed_tokens_ptr + req_state_idx) + num_computed += query_len + # Consider the rejected tokens in spec decoding. + if num_sampled > 0: + # NOTE(woosuk): We must skip num_sampled == 0 to account for chunked prefills. 
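For decode requests, the kernel writes the previous step's sampled token followed by the scheduled draft tokens into the tail of the query span, so those ids never have to be copied from the CPU; prefill requests hit the early return above because they have no sampled or draft tokens to inject. A single-request sketch of the resulting `input_ids` layout, with made-up token ids:

import torch

# One decode request whose query span holds [last_sampled, d0, d1, d2].
query_start_loc = torch.tensor([0, 4])
last_sampled_token = 101                        # from last_sampled_tokens[req_state_idx]
draft_tokens = torch.tensor([202, 203, 204])    # from draft_tokens[req_state_idx, :3]

input_ids = torch.zeros(4, dtype=torch.int64)
num_draft = draft_tokens.numel()
num_logits = num_draft + 1
query_end = int(query_start_loc[1])
input_ids[query_end - num_logits] = last_sampled_token        # slot 0
input_ids[query_end - num_draft:query_end] = draft_tokens     # slots 1..3
print(input_ids)  # tensor([101, 202, 203, 204])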
+ logits_start = tl.load(cu_num_logits_ptr + req_id) + logits_end = tl.load(cu_num_logits_ptr + req_id + 1) + num_logits = logits_end - logits_start + num_rejected = num_logits - num_sampled + num_computed -= num_rejected + tl.store(num_computed_tokens_ptr + req_state_idx, num_computed) -def update_num_computed_tokens( +def post_update( + # [num_reqs] idx_mapping: torch.Tensor, + # [max_num_reqs] num_computed_tokens: torch.Tensor, + # [max_num_reqs] + last_sampled_tokens: torch.Tensor, + # [num_reqs, num_speculative_steps + 1] + sampled_tokens: torch.Tensor, + # [num_reqs] + num_sampled: torch.Tensor, + # [num_reqs + 1] query_start_loc: torch.Tensor, + # [num_reqs + 1] + cu_num_logits: torch.Tensor, ) -> None: num_reqs = idx_mapping.shape[0] - _update_num_computed_tokens_kernel[(num_reqs,)]( + _post_update_kernel[(num_reqs,)]( idx_mapping, num_computed_tokens, + last_sampled_tokens, + sampled_tokens, + sampled_tokens.stride(0), + num_sampled, query_start_loc, + cu_num_logits, + num_warps=1, ) diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index bacfbd6c2f465..4b4ee92176f2c 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -40,11 +40,12 @@ from vllm.v1.worker.gpu.input_batch import ( InputBatch, InputBuffers, combine_sampled_and_draft_tokens, + post_update, prepare_pos_seq_lens, prepare_prefill_inputs, - update_num_computed_tokens, ) from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs +from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin @@ -100,10 +101,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.input_prep_event = None self.structured_outputs_event = None + if self.speculative_config is not None: + self.do_spec_decode = True + self.num_speculative_steps = self.speculative_config.num_speculative_tokens + else: + self.do_spec_decode = False + self.num_speculative_steps = 0 + self.req_states = RequestState( max_num_reqs=self.max_num_reqs, max_model_len=self.max_model_len, max_num_batched_tokens=self.max_num_tokens, + num_speculative_steps=self.num_speculative_steps, vocab_size=self.vocab_size, device=self.device, pin_memory=self.pin_memory, @@ -427,6 +436,32 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): idx_mapping_np = idx_mapping.np[:num_reqs] idx_mapping = idx_mapping.copy_to_gpu(num_reqs) + # Get the number of draft tokens for each request. + if not scheduler_output.scheduled_spec_decode_tokens: + # No draft token scheduled (common case). 
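`post_update` folds the per-step bookkeeping into a single launch: it records the last accepted token for reuse as the next step's input and advances `num_computed_tokens` by the query length minus any rejected draft tokens. A scalar restatement of that arithmetic for one hypothetical request (toy values, not the kernel):

# Toy request: 4 scheduled tokens (1 carried-over sampled token + 3 drafts),
# of which rejection sampling accepted 2 (num_sampled = 2).
query_len = 4                      # query_start_loc[i + 1] - query_start_loc[i]
num_logits = 4                     # cu_num_logits[i + 1] - cu_num_logits[i]
num_sampled = 2                    # 0 while the request is still chunked-prefilling
sampled_tokens = [7, 11, 0, 0]     # one row of the [num_reqs, num_steps + 1] buffer

num_computed_tokens = 100
num_computed_tokens += query_len           # every scheduled token ran through the model
if num_sampled > 0:                        # skip chunked prefills: nothing sampled yet
    num_rejected = num_logits - num_sampled
    num_computed_tokens -= num_rejected    # rejected drafts do not advance the sequence
    last_sampled_token = sampled_tokens[num_sampled - 1]   # fed back as next step's input
print(num_computed_tokens, last_sampled_token)  # 102 11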
+ total_num_draft_tokens = 0 + total_num_logits = num_reqs + cu_num_logits = torch.arange( + num_reqs + 1, device=self.device, dtype=torch.int32 + ) + else: + draft_tokens = scheduler_output.scheduled_spec_decode_tokens + num_draft_tokens = np.array( + [ + len(draft_tokens[req_id]) if req_id in draft_tokens else 0 + for req_id in req_ids + ], + dtype=np.int32, + ) + total_num_draft_tokens = int(num_draft_tokens.sum()) + total_num_logits = num_reqs + total_num_draft_tokens + + np.cumsum( + num_draft_tokens + 1, + out=self.input_buffers.cu_num_logits.np[1 : num_reqs + 1], + ) + cu_num_logits = self.input_buffers.cu_num_logits.copy_to_gpu(num_reqs + 1) + # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks] block_tables = self.block_tables.gather_block_tables(idx_mapping) @@ -456,14 +491,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): seq_lens = self.input_buffers.seq_lens[:num_reqs] # Some input token ids are directly read from the last sampled tokens - # and draft tokens. - combine_sampled_and_draft_tokens( + # and draft tokens. Also, get the logits indices to sample tokens from. + logits_indices = combine_sampled_and_draft_tokens( self.input_buffers.input_ids.gpu, idx_mapping, self.req_states.last_sampled_tokens, query_start_loc_gpu, seq_lens, self.req_states.prefill_len.gpu, + self.req_states.draft_tokens, + cu_num_logits, + total_num_logits, ) # Compute slot mappings: [num_kv_cache_groups, num_tokens] @@ -471,9 +509,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): query_start_loc_gpu, self.input_buffers.positions[:num_tokens] ) - # Logits indices to sample next token from. - logits_indices = query_start_loc_gpu[1:] - 1 - # Get num_computed_tokens. # HACK(woosuk): Here, we use num_computed_tokens on GPU instead of # num_computed_tokens_cpu. This works for most cases. @@ -508,6 +543,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_scheduled_tokens=num_scheduled_tokens, num_tokens=num_tokens, num_tokens_after_padding=num_tokens_after_padding, + num_draft_tokens=total_num_draft_tokens, query_start_loc=query_start_loc_gpu, query_start_loc_np=query_start_loc_np, seq_lens=seq_lens, @@ -516,6 +552,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): positions=positions, attn_metadata=attn_metadata, logits_indices=logits_indices, + cu_num_logits=cu_num_logits, ) def sample( @@ -530,6 +567,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if grammar_output is not None: # Apply grammar bitmask to the logits in-place. # TODO(woosuk): Make compatible with spec decoding. + assert input_batch.num_draft_tokens == 0 with async_barrier(self.structured_outputs_event): apply_grammar_bitmask( logits, @@ -539,12 +577,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.input_buffers, ) + # Sample tokens and compute logprobs (if needed). sampler_output = self.sampler(logits, sampling_metadata) + # Get the number of sampled tokens. - # 0 if chunked-prefilling, 1 if not. prefill_len = self.req_states.prefill_len.gpu[input_batch.idx_mapping] is_chunked_prefilling = input_batch.seq_lens < prefill_len - num_sampled = (~is_chunked_prefilling).int() + if input_batch.num_draft_tokens == 0: + # No draft tokens (common case). + # 0 if chunked-prefilling, 1 if not. + num_sampled = (~is_chunked_prefilling).int() + else: + # Draft tokens for spec decoding. 
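`cu_num_logits` is the exclusive prefix sum of `num_draft_tokens + 1`: each request contributes one logit per draft token plus one for the final (bonus) position. A small NumPy illustration with made-up draft counts, mirroring the `np.cumsum(..., out=...)` call above:

import numpy as np

num_reqs = 3
num_draft_tokens = np.array([2, 0, 3], dtype=np.int32)   # per-request scheduled drafts

cu_num_logits = np.zeros(num_reqs + 1, dtype=np.int32)   # buffer is zero-initialized
np.cumsum(num_draft_tokens + 1, out=cu_num_logits[1 : num_reqs + 1])

total_num_draft_tokens = int(num_draft_tokens.sum())     # 5
total_num_logits = num_reqs + total_num_draft_tokens     # 8 == cu_num_logits[-1]
print(cu_num_logits)                                     # [0 3 4 8]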
+ input_ids = input_batch.input_ids[input_batch.logits_indices] + sampled_tokens, num_sampled = rejection_sample( + sampler_output.sampled_token_ids, + input_ids, + input_batch.cu_num_logits, + self.num_speculative_steps, + ) + num_sampled *= ~is_chunked_prefilling + sampler_output.sampled_token_ids = sampled_tokens + # TODO(woosuk): Support logprobs with spec decoding. return sampler_output, num_sampled def compute_prompt_logprobs( @@ -653,11 +707,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_sampled: torch.Tensor, ) -> None: # Update the number of computed tokens. - update_num_computed_tokens( + post_update( input_batch.idx_mapping, self.req_states.num_computed_tokens, + self.req_states.last_sampled_tokens, + sampled_tokens, + num_sampled, input_batch.query_start_loc, + input_batch.cu_num_logits, ) + + # Update the number of computed prefill tokens. idx_mapping_np = input_batch.idx_mapping_np computed_prefill = self.req_states.num_computed_prefill_tokens # TODO(woosuk): Simplify this. @@ -666,10 +726,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.req_states.prefill_len.np[idx_mapping_np], ) - # Store the last sampled token ids. - last_sampled = sampled_tokens - self.req_states.last_sampled_tokens[input_batch.idx_mapping] = last_sampled - def get_cudagraph_and_dp_padding( self, scheduler_output: SchedulerOutput, @@ -761,6 +817,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): sampling_metadata = self.req_states.make_sampling_metadata( input_batch.idx_mapping_np, pos ) + if input_batch.num_draft_tokens > 0: + sampling_metadata = self.req_states.expand_sampling_metadata( + sampling_metadata, input_batch.cu_num_logits + ) if self.lora_config: # Activate LoRA adapters. 
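The branch above gathers the draft ids that were fed to the model (`input_ids` at `logits_indices`) and lets `rejection_sample` compare them against the target model's samples. A hedged pure-PyTorch restatement of that acceptance rule, looping per request for clarity rather than using the Triton kernel; `rejection_sample_reference` is a name invented here:

import torch

def rejection_sample_reference(target_sampled, input_ids, cu_num_logits, num_steps):
    """CPU sketch: accept target tokens until the first draft mismatch, plus one token."""
    num_reqs = cu_num_logits.numel() - 1
    sampled = torch.zeros(num_reqs, num_steps + 1, dtype=target_sampled.dtype)
    num_sampled = torch.zeros(num_reqs, dtype=torch.int32)
    for req in range(num_reqs):
        start, end = int(cu_num_logits[req]), int(cu_num_logits[req + 1])
        n = 0
        for i in range(end - start - 1):
            sampled[req, i] = target_sampled[start + i]   # target's token at draft position i
            n += 1
            if target_sampled[start + i] != input_ids[start + i + 1]:
                break                                     # draft i was rejected; stop here
        else:
            # Every draft matched: also take the bonus token from the last position.
            sampled[req, end - start - 1] = target_sampled[end - 1]
            n += 1
        num_sampled[req] = n
    return sampled, num_sampled

As in the kernel, the mismatching position still contributes the target model's corrected token, and `num_sampled` is subsequently zeroed for requests that are still chunked-prefilling.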
diff --git a/vllm/v1/worker/gpu/spec_decode/__init__.py b/vllm/v1/worker/gpu/spec_decode/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/v1/worker/gpu/spec_decode/rejection_sample.py b/vllm/v1/worker/gpu/spec_decode/rejection_sample.py new file mode 100644 index 0000000000000..8a7bf28bacbd4 --- /dev/null +++ b/vllm/v1/worker/gpu/spec_decode/rejection_sample.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +from vllm.triton_utils import tl, triton + + +@triton.jit +def _rejection_sample_kernel( + sampled_ptr, # [num_reqs, num_speculative_steps + 1] + sampled_stride, + num_sampled_ptr, # [num_reqs] + target_sampled_ptr, # [num_draft_tokens + num_reqs] + input_ids_ptr, # [num_draft_tokens + num_reqs] + cu_num_logits_ptr, # [num_reqs + 1] +): + req_idx = tl.program_id(0) + start_idx = tl.load(cu_num_logits_ptr + req_idx) + end_idx = tl.load(cu_num_logits_ptr + req_idx + 1) + num_tokens = end_idx - start_idx + + num_sampled = 0 + rejected = False + for i in range(num_tokens - 1): + if not rejected: + target_sampled = tl.load(target_sampled_ptr + start_idx + i) + draft_sampled = tl.load(input_ids_ptr + start_idx + i + 1) + tl.store(sampled_ptr + req_idx * sampled_stride + i, target_sampled) + num_sampled += 1 + if target_sampled != draft_sampled: + rejected = True + if not rejected: + target_sampled = tl.load(target_sampled_ptr + start_idx + num_tokens - 1) + tl.store( + sampled_ptr + req_idx * sampled_stride + num_tokens - 1, target_sampled + ) + num_sampled += 1 + tl.store(num_sampled_ptr + req_idx, num_sampled) + + +def rejection_sample( + # [num_draft_tokens + num_reqs] + target_sampled: torch.Tensor, + # [num_draft_tokens + num_reqs] + input_ids: torch.Tensor, + # [num_reqs + 1] + cu_num_logits: torch.Tensor, + num_speculative_steps: int, +) -> tuple[torch.Tensor, torch.Tensor]: + num_reqs = cu_num_logits.shape[0] - 1 + sampled = torch.empty( + num_reqs, + num_speculative_steps + 1, + dtype=target_sampled.dtype, + device=target_sampled.device, + ) + num_sampled = torch.empty( + num_reqs, + dtype=torch.int32, + device=target_sampled.device, + ) + _rejection_sample_kernel[(num_reqs,)]( + sampled, + sampled.stride(0), + num_sampled, + target_sampled, + input_ids, + cu_num_logits, + num_warps=1, + ) + return sampled, num_sampled diff --git a/vllm/v1/worker/gpu/states.py b/vllm/v1/worker/gpu/states.py index e8a3207a3a53e..513d45d95d7cd 100644 --- a/vllm/v1/worker/gpu/states.py +++ b/vllm/v1/worker/gpu/states.py @@ -7,6 +7,7 @@ import torch from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams +from vllm.triton_utils import tl, triton from vllm.v1.outputs import LogprobsTensors from vllm.v1.utils import CpuGpuBuffer @@ -63,6 +64,7 @@ class RequestState: max_num_reqs: int, max_model_len: int, max_num_batched_tokens: int, + num_speculative_steps: int, vocab_size: int, device: torch.device, pin_memory: bool, @@ -70,6 +72,7 @@ class RequestState: self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len self.max_num_batched_tokens = max_num_batched_tokens + self.num_speculative_steps = num_speculative_steps self.vocab_size = vocab_size self.device = device self.pin_memory = pin_memory @@ -100,6 +103,14 @@ class RequestState: device=device, ) + # Draft tokens. + self.draft_tokens = torch.zeros( + self.max_num_reqs, + self.num_speculative_steps, + dtype=torch.int64, + device=device, + ) + # LoRA. 
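A hand-traced example of the rejection-sampling kernel above, with made-up token ids for a single request that scheduled three draft tokens (running the real kernel needs a GPU; these are the values it would produce):

import torch

cu_num_logits = torch.tensor([0, 4])
input_ids = torch.tensor([17, 5, 8, 7])      # [last_sampled, d0, d1, d2] at logits_indices
target_sampled = torch.tensor([5, 9, 7, 3])  # target model's sample at each logits row

# i = 0: target 5 == draft d0 (5) -> accept 5 and continue.
# i = 1: target 9 != draft d1 (8) -> accept the corrected token 9 and stop.
# The bonus token (last row) would only be taken if every draft had matched.
expected_sampled_row = [5, 9]                # first row of `sampled` (rest is padding)
expected_num_sampled = 2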
self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32) self.lora_ids.fill(NO_LORA_ID) @@ -226,6 +237,17 @@ class RequestState: max_num_logprobs=max_num_logprobs, ) + def expand_sampling_metadata( + self, + sampling_metadata: SamplingMetadata, + cu_num_logits: torch.Tensor, + ) -> SamplingMetadata: + # For draft tokens, we need to expand the sampling param tensors as + # each request samples multiple tokens in each step. + return expand_sampling_metadata( + sampling_metadata, cu_num_logits, self.num_speculative_steps + ) + def make_lora_inputs( self, req_ids: list[str], @@ -270,3 +292,75 @@ class Param: class ExtraData: lora_request: LoRARequest | None in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list) + + +# NOTE(woosuk): Re-compilation can happen at runtime since top_p and top_k can be None. +@triton.jit +def _expand_sampling_metadata_kernel( + temp_ptr, + expanded_temp_ptr, + top_p_ptr, + expanded_top_p_ptr, + top_k_ptr, + expanded_top_k_ptr, + seeds_ptr, + expanded_seeds_ptr, + cu_num_logits_ptr, + BLOCK_SIZE: tl.constexpr, +): + req_idx = tl.program_id(0) + start_idx = tl.load(cu_num_logits_ptr + req_idx) + end_idx = tl.load(cu_num_logits_ptr + req_idx + 1) + num_tokens = end_idx - start_idx + + block = tl.arange(0, BLOCK_SIZE) + mask = block < num_tokens + + temp = tl.load(temp_ptr + req_idx) + tl.store(expanded_temp_ptr + start_idx + block, temp, mask=mask) + + if top_p_ptr is not None: + top_p = tl.load(top_p_ptr + req_idx) + tl.store(expanded_top_p_ptr + start_idx + block, top_p, mask=mask) + + if top_k_ptr is not None: + top_k = tl.load(top_k_ptr + req_idx) + tl.store(expanded_top_k_ptr + start_idx + block, top_k, mask=mask) + + seed = tl.load(seeds_ptr + req_idx) + tl.store(expanded_seeds_ptr + start_idx + block, seed, mask=mask) + + +def expand_sampling_metadata( + sampling_metadata: SamplingMetadata, + cu_num_logits: torch.Tensor, + num_speculative_steps: int, +) -> SamplingMetadata: + total_num_logits = sampling_metadata.pos.shape[0] + create_empty = lambda x: x.new_empty(total_num_logits) if x is not None else None + expanded_temp = create_empty(sampling_metadata.temperature) + expanded_top_p = create_empty(sampling_metadata.top_p) + expanded_top_k = create_empty(sampling_metadata.top_k) + expanded_seeds = create_empty(sampling_metadata.seeds) + + num_reqs = cu_num_logits.shape[0] - 1 + _expand_sampling_metadata_kernel[(num_reqs,)]( + sampling_metadata.temperature, + expanded_temp, + sampling_metadata.top_p, + expanded_top_p, + sampling_metadata.top_k, + expanded_top_k, + sampling_metadata.seeds, + expanded_seeds, + cu_num_logits, + BLOCK_SIZE=triton.next_power_of_2(num_speculative_steps + 1), + ) + return SamplingMetadata( + temperature=expanded_temp, + top_p=expanded_top_p, + top_k=expanded_top_k, + seeds=expanded_seeds, + pos=sampling_metadata.pos, + max_num_logprobs=sampling_metadata.max_num_logprobs, + )
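On the CPU, the expansion performed by `_expand_sampling_metadata_kernel` is equivalent to a `repeat_interleave` over the per-request logit counts; the kernel presumably exists to keep the operation on the GPU without synchronizing on the counts. A minimal equivalence sketch with toy values (only `temperature` shown; `top_p`, `top_k`, and `seeds` get the same treatment when present):

import torch

cu_num_logits = torch.tensor([0, 3, 4, 8])
num_logits_per_req = torch.diff(cu_num_logits)        # [3, 1, 4]

temperature = torch.tensor([0.7, 1.0, 0.0])           # one value per request
expanded = torch.repeat_interleave(temperature, num_logits_per_req)
print(expanded)  # 0.7 repeated 3x, 1.0 once, 0.0 four times -> 8 entries total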