From 6ebaf43ee4a6fbbeba685315d605536db1c0c471 Mon Sep 17 00:00:00 2001
From: Sergei Skvortsov
Date: Tue, 7 Oct 2025 21:02:49 +0100
Subject: [PATCH] [V1] Logit processors for rejection sampler (#19482)

Signed-off-by: southfreebird
Signed-off-by: Sergei Skvortsov
Signed-off-by: Sergei Skvortsov
Co-authored-by: Sergei Skvortsov
Co-authored-by: Nick Hill
---
 .../logits_processors/test_custom_offline.py |  59 ++++--
 tests/v1/sample/test_rejection_sampler.py    | 180 +++++++++++++++++-
 tests/v1/sample/test_sampler.py              |  46 ++---
 tests/v1/sample/utils.py                     |  20 ++
 tests/v1/worker/test_gpu_input_batch.py      |   1 +
 vllm/v1/sample/logits_processor/__init__.py  |  17 ++
 vllm/v1/sample/metadata.py                   |   3 +
 vllm/v1/sample/ops/bad_words.py              |  17 ++
 vllm/v1/sample/rejection_sampler.py          | 104 ++++++++++
 vllm/v1/sample/sampler.py                    |  96 ++++++----
 vllm/v1/worker/gpu_input_batch.py            |  17 ++
 vllm/v1/worker/gpu_model_runner.py           |   3 +
 12 files changed, 471 insertions(+), 92 deletions(-)

diff --git a/tests/v1/logits_processors/test_custom_offline.py b/tests/v1/logits_processors/test_custom_offline.py
index f57a21dce516f..95ddb18491691 100644
--- a/tests/v1/logits_processors/test_custom_offline.py
+++ b/tests/v1/logits_processors/test_custom_offline.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import random
 import sys
-from typing import Union
+from typing import Any, Union
 
 import pytest
 
@@ -25,6 +25,7 @@
 from tests.v1.logits_processors.utils import entry_points as fake_entry_points
 from vllm import LLM, SamplingParams
 from vllm.v1.sample.logits_processor import (
     STR_POOLING_REJECTS_LOGITSPROCS,
+    STR_SPEC_DEC_REJECTS_LOGITSPROCS,
     LogitsProcessor,
 )
 
@@ -205,6 +206,7 @@ def test_custom_logitsprocs_req(monkeypatch):
 
 
 @create_new_process_for_each_test()
+@pytest.mark.parametrize("model_scenario", ["pooling", "spec_dec"])
 @pytest.mark.parametrize(
     "logitproc_source",
     [
         CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT,
         CustomLogitprocSource.LOGITPROC_SOURCE_FQCN,
         CustomLogitprocSource.LOGITPROC_SOURCE_CLASS,
     ],
 )
-def test_pooling_rejects_custom_logitsprocs(
-    monkeypatch, logitproc_source: CustomLogitprocSource
+def test_rejects_custom_logitsprocs(
+    monkeypatch, model_scenario: str, logitproc_source: CustomLogitprocSource
 ):
     """Validate that vLLM engine initialization properly rejects custom
-    logitsprocs when the model is a pooling model.
+    logitsprocs when the model is a pooling model or speculative decoding
+    is enabled.
 
     Use `LLM` entrypoint. We expect `LLM` initialization to fail before the
     logitproc is actually loaded.
@@ -241,8 +244,32 @@ def test_pooling_rejects_custom_logitsprocs( monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") random.seed(40) + test_params: dict[str, dict[str, Any]] = { + "pooling": { + "runner": "pooling", + "model": POOLING_MODEL_NAME, + "error_message": STR_POOLING_REJECTS_LOGITSPROCS, + "speculative_config": None, + }, + "spec_dec": { + "runner": "auto", + "model": MODEL_NAME, + "error_message": STR_SPEC_DEC_REJECTS_LOGITSPROCS, + "speculative_config": {"model": "ngram", "num_speculative_tokens": 1}, + }, + } + + config = test_params[model_scenario] + + llm_kwargs: dict[str, Any] = { + "runner": config["runner"], + "model": config["model"], + "gpu_memory_utilization": 0.1, + "speculative_config": config["speculative_config"], + } + if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT: - # Scenario: vLLM loads a pooling model and ignores a logitproc that is + # Scenario: vLLM loads a model and ignores a logitproc that is # available at a preconfigured entrypoint # Patch in dummy logitproc entrypoint @@ -254,30 +281,20 @@ def test_pooling_rejects_custom_logitsprocs( # although they should ignore the entrypoint patch anyway monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "fork") - llm = LLM( - runner="pooling", - model=POOLING_MODEL_NAME, - gpu_memory_utilization=0.1, - ) + llm = LLM(**llm_kwargs) # Require that no logitsprocs have been loaded worker = llm.llm_engine.model_executor.driver_worker.worker assert sum([1 for _ in worker.model_runner.input_batch.logitsprocs.all]) == 0 return - kwargs: dict[str, list[Union[str, type[LogitsProcessor]]]] = {} if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN: # Scenario: load logitproc based on fully-qualified class name (FQCN) - kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN] + llm_kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN] elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS: # Scenario: load logitproc from provided class object - kwargs["logits_processors"] = [DummyLogitsProcessor] + llm_kwargs["logits_processors"] = [DummyLogitsProcessor] - with pytest.raises(ValueError, match=STR_POOLING_REJECTS_LOGITSPROCS): - # Require that loading a pooling model alongside the logitproc raises + with pytest.raises(ValueError, match=config["error_message"]): + # Require that loading a model alongside the logitproc raises # the appropriate exception. 
- LLM( - runner="pooling", - model=POOLING_MODEL_NAME, - gpu_memory_utilization=0.1, - **kwargs, - ) + LLM(**llm_kwargs) diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index 36e2e2698810b..8df10f8c3afa5 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -6,6 +6,7 @@ import pytest import torch import torch.nn.functional as F +from tests.v1.sample.utils import create_allowed_token_ids from vllm.platforms import current_platform from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata @@ -21,7 +22,9 @@ def rejection_sampler(): def create_logits_tensor( - output_token_ids: list[list[int]], vocab_size: int = 100 + output_token_ids: list[list[int]], + vocab_size: int = 100, + token_idx_to_override: Optional[int] = None, ) -> torch.Tensor: """Helper function to create logits tensor that will produce desired token ids on argmax""" @@ -33,15 +36,25 @@ def create_logits_tensor( for j, token_id in enumerate(tokens): logits[start_loc + j, token_id] = 100.0 start_loc += len(tokens) + if token_idx_to_override: + logits[:, token_idx_to_override] = 99.0 return logits def create_sampling_metadata( all_greedy: bool, + output_token_ids: Optional[list[list[int]]] = None, + prompt_token_ids: Optional[torch.Tensor] = None, + spec_token_ids: Optional[torch.Tensor] = None, temperature: Optional[torch.Tensor] = None, top_k: Optional[torch.Tensor] = None, top_p: Optional[torch.Tensor] = None, generators: Optional[dict[int, Any]] = None, + frequency_penalties: Optional[list[float]] = None, + presence_penalties: Optional[list[float]] = None, + repetition_penalties: Optional[list[float]] = None, + bad_words_token_ids: Optional[dict[int, list[list[int]]]] = None, + allowed_token_ids_mask: Optional[torch.Tensor] = None, ) -> SamplingMetadata: """Create a v1 sampling metadata object with all_greedy set to the given value. 
Either all greedy or all random sampling
@@ -53,6 +66,21 @@ def create_sampling_metadata(
     else:
         assert temperature is not None
 
+    if any([frequency_penalties, presence_penalties, repetition_penalties]):
+        no_penalties = False
+
+        assert output_token_ids
+        assert len(output_token_ids) > 0
+
+        frequency_penalties = torch.tensor(frequency_penalties, device=DEVICE)
+        presence_penalties = torch.tensor(presence_penalties, device=DEVICE)
+        repetition_penalties = torch.tensor(repetition_penalties, device=DEVICE)
+    else:
+        no_penalties = True
+        frequency_penalties = torch.tensor([])
+        presence_penalties = torch.tensor([])
+        repetition_penalties = torch.tensor([])
+
     return SamplingMetadata(
         temperature=temperature,
         all_greedy=all_greedy,
@@ -61,14 +89,15 @@
         top_k=top_k,
         generators=generators,
         max_num_logprobs=0,
-        no_penalties=False,
-        prompt_token_ids=None,
-        frequency_penalties=torch.tensor([]),
-        presence_penalties=torch.tensor([]),
-        repetition_penalties=torch.tensor([]),
-        output_token_ids=[],
-        allowed_token_ids_mask=None,
-        bad_words_token_ids={},
+        no_penalties=no_penalties,
+        prompt_token_ids=prompt_token_ids,
+        frequency_penalties=frequency_penalties,
+        presence_penalties=presence_penalties,
+        repetition_penalties=repetition_penalties,
+        output_token_ids=[] if output_token_ids is None else output_token_ids,
+        spec_token_ids=[] if spec_token_ids is None else spec_token_ids,
+        allowed_token_ids_mask=allowed_token_ids_mask,
+        bad_words_token_ids={} if bad_words_token_ids is None else bad_words_token_ids,
         logitsprocs=LogitsProcessors(),
     )
 
@@ -611,3 +640,136 @@ def test_top_p(rejection_sampler, top_p):
         unmasked_indices=top_p_indices,
         sampling_metadata=sampling_metadata,
     )
+
+
+########################### Tests for Logit Processors ###################
+def test_frequency_penalties(rejection_sampler):
+    """Test rejection sampling with frequency penalties"""
+    spec_tokens = [[1, 1, 1], [], [1, 1, 1]]
+    output_tokens = [[1, 1, 1, 1], [7], [1, 1, 1, 1]]  # 1, 7, and 1 are the bonus tokens
+
+    num_requests = len(spec_tokens)
+    logits = create_logits_tensor(output_tokens, token_idx_to_override=15)
+    metadata = create_sampling_metadata(
+        all_greedy=True,
+        output_token_ids=[[2], [3], [4]],
+        spec_token_ids=spec_tokens,
+        prompt_token_ids=torch.tensor([[5, 6, 7], [6, 7, 8], [7, 8, 9]], device=DEVICE),
+        frequency_penalties=[1.5, 1.5, 0.7],
+        presence_penalties=[0.0] * num_requests,
+        repetition_penalties=[1.0] * num_requests,
+    )
+    bonus_token_tensor = torch.tensor(
+        [output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device
+    )
+    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
+        spec_tokens, device=logits.device
+    )
+    output = rejection_sampler(
+        spec_decode_metadata,
+        draft_probs=None,
+        target_logits=logits,
+        bonus_token_ids=bonus_token_tensor,
+        sampling_metadata=metadata,
+    )
+    expected = torch.tensor(
+        [[1, 15, -1, -1], [7, -1, -1, -1], [1, 1, 15, -1]],
+        dtype=torch.int,
+        device=logits.device,
+    )
+    assert torch.equal(output, expected)
+
+
+def test_bad_words(rejection_sampler):
+    """Test rejection sampling with bad words constraints"""
+    spec_tokens = [[1, 2, 3], [1, 15, 3], [1, 2, 3]]
+    output_tokens = [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]
+
+    logits = create_logits_tensor(output_tokens, token_idx_to_override=15)
+    metadata = create_sampling_metadata(
+        all_greedy=True,
+        output_token_ids=[[2], [3], [4]],
+        spec_token_ids=spec_tokens,
+        bad_words_token_ids={
+            0: [
+                [
+                    2,
+                ]
+            ],
+            1: [
+                [
+                    2,
+                ]
+            ],
+            # Do not apply bad words to the last request
        },
+    )
+    bonus_token_tensor = torch.tensor(
+        [output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device
+    )
+    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
+        spec_tokens, device=logits.device
+    )
+    output = rejection_sampler(
+        spec_decode_metadata,
+        draft_probs=None,
+        target_logits=logits,
+        bonus_token_ids=bonus_token_tensor,
+        sampling_metadata=metadata,
+    )
+
+    expected = torch.tensor(
+        [[1, 15, -1, -1], [1, 15, 3, 4], [1, 2, 3, 4]],
+        dtype=torch.int,
+        device=logits.device,
+    )
+    assert torch.equal(output, expected)
+
+
+def test_allowed_token_ids(rejection_sampler):
+    """Test rejection sampling with allowed token ids"""
+    spec_tokens = [[1, 2, 10], [10, 5, 3], [7, 10, 12]]
+    output_tokens = [[1, 2, 10, 5], [10, 5, 10, 5], [7, 10, 12, 5]]
+    # Not allowed (masked-out) tokens per request:
+    # 0: 0-4
+    # 1: none (odd-indexed requests are not masked)
+    # 2: 2-6
+    num_allowed_token_ids = 5
+
+    # Use token 15 as the token the sampler picks when a draft token is rejected
+    logits = create_logits_tensor(output_tokens, token_idx_to_override=15)
+
+    batch_size = len(output_tokens)
+    _, vocab_size = logits.size()
+    mask = create_allowed_token_ids(
+        batch_size=batch_size,
+        vocab_size=vocab_size,
+        num_allowed_token_ids=num_allowed_token_ids,
+        device=logits.device,
+    )
+    metadata = create_sampling_metadata(
+        all_greedy=True,
+        output_token_ids=[[], [], []],
+        spec_token_ids=spec_tokens,
+        allowed_token_ids_mask=mask,
+    )
+    bonus_token_tensor = torch.tensor(
+        [output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device
+    )
+    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
+        spec_tokens, device=logits.device
+    )
+    output = rejection_sampler(
+        spec_decode_metadata,
+        draft_probs=None,
+        target_logits=logits,
+        bonus_token_ids=bonus_token_tensor,
+        sampling_metadata=metadata,
+    )
+
+    expected = torch.tensor(
+        [[15, -1, -1, -1], [10, 5, 10, -1], [7, 10, 12, 5]],
+        dtype=torch.int,
+        device=logits.device,
+    )
+    assert torch.equal(output, expected)
diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py
index 5b34e27e79ac0..edc6acae848aa 100644
--- a/tests/v1/sample/test_sampler.py
+++ b/tests/v1/sample/test_sampler.py
@@ -1,12 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Optional
-
 import numpy as np
 import pytest
 import torch
 
+from tests.v1.sample.utils import create_allowed_token_ids
 from vllm.platforms import current_platform
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
 from vllm.v1.sample.logits_processor import LogitsProcessors
@@ -51,26 +50,6 @@ def _create_prompt_tokens_tensor(
     )
 
 
-def _create_allowed_token_ids(
-    batch_size: int,
-    vocab_size: int,
-    num_allowed_token_ids: int,
-    device: torch.device,
-) -> Optional[torch.Tensor]:
-    mask: Optional[torch.Tensor] = None
-    for i in range(batch_size):
-        if i % 2 == 1:
-            continue
-        if mask is None:
-            mask = torch.zeros(
-                (batch_size, vocab_size), dtype=torch.bool, device=device
-            )
-        start = min(i, vocab_size - 1)
-        end = min(i + num_allowed_token_ids, vocab_size - 1)
-        mask[i, start:end] = True
-    return mask
-
-
 def _create_bad_words_token_ids(
     batch_size: int,
     vocab_size: int,
@@ -173,6 +152,7 @@ def _create_default_sampling_metadata(
             prompt_token_ids, vocab_size, device
         ),
         output_token_ids=output_token_ids,
+        spec_token_ids=[[] for _ in range(batch_size)],
         frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
         presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device), @@ -241,7 +221,9 @@ def test_sampler_presence_penalty( ) sampling_metadata.no_penalties = False sampler = Sampler() - logits = sampler.apply_penalties(fake_logits, sampling_metadata) + logits = sampler.apply_penalties( + fake_logits, sampling_metadata, sampling_metadata.output_token_ids + ) logits = logits.cpu() for batch_idx in range(batch_size): # Since all tokens initially have the same logits, the non-penalized @@ -293,7 +275,9 @@ def test_sampler_frequency_penalty( sampling_metadata.output_token_ids = output_token_ids sampling_metadata.no_penalties = False sampler = Sampler() - logits = sampler.apply_penalties(fake_logits, sampling_metadata) + logits = sampler.apply_penalties( + fake_logits, sampling_metadata, sampling_metadata.output_token_ids + ) logits = logits.cpu() for batch_idx in range(batch_size): non_penalized_token_id = logits[batch_idx].argmax().item() @@ -343,7 +327,9 @@ def test_sampler_repetition_penalty( ) sampling_metadata.no_penalties = False sampler = Sampler() - logits = sampler.apply_penalties(fake_logits, sampling_metadata) + logits = sampler.apply_penalties( + fake_logits, sampling_metadata, sampling_metadata.output_token_ids + ) logits = logits.cpu() for batch_idx in range(batch_size): non_penalized_token_id = logits[batch_idx].argmax().item() @@ -394,7 +380,7 @@ def test_sampler_allowed_token_ids( sampling_metadata = _create_default_sampling_metadata( NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device) ) - mask = _create_allowed_token_ids( + mask = create_allowed_token_ids( batch_size=batch_size, vocab_size=VOCAB_SIZE, num_allowed_token_ids=num_allowed_token_ids, @@ -402,7 +388,9 @@ def test_sampler_allowed_token_ids( ) sampling_metadata.allowed_token_ids_mask = mask sampler = Sampler() - logits = sampler.apply_allowed_token_ids(fake_logits, sampling_metadata) + logits = sampler.apply_logits_processors( + fake_logits, sampling_metadata, predict_bonus_token=False + ) logits = logits.cpu() for batch_idx in range(batch_size): logits_for_req = logits[batch_idx] @@ -444,7 +432,9 @@ def test_sampler_bad_words( sampling_metadata, VOCAB_SIZE ) sampler = Sampler() - logits = sampler.apply_bad_words(fake_logits, sampling_metadata) + logits = sampler.apply_logits_processors( + fake_logits, sampling_metadata, predict_bonus_token=False + ) logits = logits.cpu() for batch_idx in range(batch_size): logits_for_req = logits[batch_idx] diff --git a/tests/v1/sample/utils.py b/tests/v1/sample/utils.py index 0f1214e9745c5..b1c63327b852b 100644 --- a/tests/v1/sample/utils.py +++ b/tests/v1/sample/utils.py @@ -215,3 +215,23 @@ def fake_apply_logitsprocs( for processor in test_fakes.get_logitsprocs(): logits = processor.apply(logits) return logits + + +def create_allowed_token_ids( + batch_size: int, + vocab_size: int, + num_allowed_token_ids: int, + device: torch.device, +) -> Optional[torch.Tensor]: + mask: Optional[torch.Tensor] = None + for i in range(batch_size): + if i % 2 == 1: + continue + if mask is None: + mask = torch.zeros( + (batch_size, vocab_size), dtype=torch.bool, device=device + ) + start = min(i, vocab_size - 1) + end = min(i + num_allowed_token_ids, vocab_size - 1) + mask[i, start:end] = True + return mask diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index c834577f1adb6..e72bd43ff56e6 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -170,6 +170,7 @@ def 
_construct_expected_sampling_metadata(
             repetition_penalties, dtype=torch.float, device=device
         ),
         output_token_ids=output_token_ids,
+        spec_token_ids=[[] for _ in range(len(output_token_ids))],
         no_penalties=(
             all(x == 0 for x in presence_penalties)
             and all(x == 0 for x in frequency_penalties)
diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py
index 98c4d8bad02d3..e9935f72c17f2 100644
--- a/vllm/v1/sample/logits_processor/__init__.py
+++ b/vllm/v1/sample/logits_processor/__init__.py
@@ -37,6 +37,12 @@
 STR_POOLING_REJECTS_LOGITSPROCS = (
     "Pooling models do not support custom logits processors."
 )
+# Error message when the user tries to initialize vLLM with speculative
+# decoding enabled and custom logits processors
+STR_SPEC_DEC_REJECTS_LOGITSPROCS = (
+    "Custom logits processors are not supported when speculative decoding is enabled."
+)
+
 LOGITSPROCS_GROUP = "vllm.logits_processors"
 
 BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [
@@ -185,6 +191,17 @@ def build_logitsprocs(
             " do not support logits processors."
         )
         return LogitsProcessors()
+
+    # Check if speculative decoding is enabled.
+    if vllm_config.speculative_config:
+        if custom_logitsprocs:
+            raise ValueError(STR_SPEC_DEC_REJECTS_LOGITSPROCS)
+        logger.warning(
+            "min_p, logit_bias, and min_tokens parameters won't currently work "
+            "with speculative decoding enabled."
+        )
+        return LogitsProcessors()
+
     custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs)
     return LogitsProcessors(
         ctor(vllm_config, device, is_pin_memory)
diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py
index 14895db1bd553..e252ace97d27e 100644
--- a/vllm/v1/sample/metadata.py
+++ b/vllm/v1/sample/metadata.py
@@ -40,3 +40,6 @@ class SamplingMetadata:
 
     # Loaded logits processors
     logitsprocs: LogitsProcessors
+
+    # Speculative token ids
+    spec_token_ids: Optional[list[list[int]]] = None
diff --git a/vllm/v1/sample/ops/bad_words.py b/vllm/v1/sample/ops/bad_words.py
index faa4c33cc7933..8e2c798dd35ff 100644
--- a/vllm/v1/sample/ops/bad_words.py
+++ b/vllm/v1/sample/ops/bad_words.py
@@ -33,3 +33,20 @@ def apply_bad_words(
 ) -> None:
     for i, bad_words_ids in bad_words_token_ids.items():
         _apply_bad_words_single_batch(logits[i], bad_words_ids, past_tokens_ids[i])
+
+
+def apply_bad_words_with_drafts(
+    logits: torch.Tensor,
+    bad_words_token_ids: dict[int, list[list[int]]],
+    past_tokens_ids: list[list[int]],
+    num_draft_tokens: list[int],
+) -> None:
+    # Row offsets are accumulated over all requests, not only the requests
+    # that have bad words, so sparse request keys index the correct rows.
+    start_indices = [0]
+    for n in num_draft_tokens:
+        start_indices.append(start_indices[-1] + n)
+    for i, bad_words_ids in bad_words_token_ids.items():
+        start_idx = start_indices[i]
+        for draft_idx in range(num_draft_tokens[i]):
+            _apply_bad_words_single_batch(
+                logits[start_idx + draft_idx],
+                bad_words_ids,
+                past_tokens_ids[start_idx + draft_idx],
+            )
diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py
index 5f1dbf07d1f07..76555a8666857 100644
--- a/vllm/v1/sample/rejection_sampler.py
+++ b/vllm/v1/sample/rejection_sampler.py
@@ -8,6 +8,8 @@
 import torch.nn as nn
 
 from vllm.logger import init_logger
 from vllm.triton_utils import tl, triton
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts
+from vllm.v1.sample.ops.penalties import apply_all_penalties
 from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 
@@ -83,6 +85,14 @@ class RejectionSampler(nn.Module):
             A tensor containing the final output token IDs.
""" assert metadata.max_spec_len <= MAX_SPEC_LEN + + # Use float32 for the target_logits. + target_logits = target_logits.to(torch.float32) + + target_logits = self.apply_logits_processors( + target_logits, sampling_metadata, metadata + ) + # [num_tokens, vocab_size] # NOTE(woosuk): `target_logits` can be updated in place inside the # `compute_probs` function. @@ -131,6 +141,100 @@ class RejectionSampler(nn.Module): ] return outputs + def apply_logits_processors( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + metadata: SpecDecodeMetadata, + ) -> torch.Tensor: + any_penalties_or_bad_words = ( + sampling_metadata.bad_words_token_ids or not sampling_metadata.no_penalties + ) + + output_token_ids = sampling_metadata.output_token_ids + if any_penalties_or_bad_words: + output_token_ids = self._combine_outputs_with_spec_tokens( + sampling_metadata.output_token_ids, + sampling_metadata.spec_token_ids, + ) + + # Calculate indices of target logits. + if ( + sampling_metadata.allowed_token_ids_mask is not None + or not sampling_metadata.no_penalties + ): + num_requests = len(sampling_metadata.output_token_ids) + num_draft_tokens = torch.tensor(metadata.num_draft_tokens, device="cpu") + original_indices = torch.arange(num_requests, device="cpu") + repeat_indices_cpu = original_indices.repeat_interleave(num_draft_tokens) + repeat_indices = repeat_indices_cpu.to( + device=logits.device, non_blocking=True + ) + logits = self.apply_penalties( + logits, sampling_metadata, metadata, repeat_indices, output_token_ids + ) + + # Apply allowed token ids. + if sampling_metadata.allowed_token_ids_mask is not None: + token_mask = sampling_metadata.allowed_token_ids_mask[repeat_indices] + logits.masked_fill_(token_mask, float("-inf")) + + # Apply bad words exclusion. 
+ if sampling_metadata.bad_words_token_ids: + apply_bad_words_with_drafts( + logits, + sampling_metadata.bad_words_token_ids, + output_token_ids, + metadata.num_draft_tokens, + ) + + return logits + + def apply_penalties( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + metadata: SpecDecodeMetadata, + repeat_indices: torch.Tensor, + output_token_ids: list[list[int]], + ) -> torch.Tensor: + if sampling_metadata.no_penalties: + return logits + + assert sampling_metadata.prompt_token_ids is not None + + prompt_token_ids = sampling_metadata.prompt_token_ids[repeat_indices] + presence_penalties = sampling_metadata.presence_penalties[repeat_indices] + frequency_penalties = sampling_metadata.frequency_penalties[repeat_indices] + repetition_penalties = sampling_metadata.repetition_penalties[repeat_indices] + + logits = apply_all_penalties( + logits, + prompt_token_ids, + presence_penalties, + frequency_penalties, + repetition_penalties, + output_token_ids, + ) + return logits + + def _combine_outputs_with_spec_tokens( + self, + output_token_ids: list[list[int]], + spec_token_ids: Optional[list[list[int]]] = None, + ) -> list[list[int]]: + if spec_token_ids is None: + return output_token_ids + + result = [] + for out, spec in zip(output_token_ids, spec_token_ids): + if len(spec) == 0: + continue + result.append(out) + for i in range(len(spec) - 1): + result.append([*result[-1], spec[i]]) + return result + def rejection_sample( # [num_tokens] diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index d4d3fb029599e..101d2ebed4b75 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -70,6 +70,7 @@ class Sampler(nn.Module): self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, + predict_bonus_token: bool = False, ) -> SamplerOutput: # NOTE(woosuk): Use the original logits (before any penalties or # temperature scaling) for the top-k logprobs. @@ -84,18 +85,10 @@ class Sampler(nn.Module): # Use float32 for the logits. logits = logits.to(torch.float32) - # Apply allowed token ids. - logits = self.apply_allowed_token_ids(logits, sampling_metadata) - # Apply bad words exclusion. - logits = self.apply_bad_words(logits, sampling_metadata) - - # Apply logits processors which can impact greedy sampling - for processor in sampling_metadata.logitsprocs.non_argmax_invariant: - logits = processor.apply(logits) - - # Apply penalties (e.g., min_tokens, freq_penalties). - logits = self.apply_penalties(logits, sampling_metadata) + logits = self.apply_logits_processors( + logits, sampling_metadata, predict_bonus_token + ) # Sample the next token. 
sampled, processed_logprobs = self.sample(logits, sampling_metadata) if processed_logprobs is not None: @@ -245,10 +238,65 @@ class Sampler(nn.Module): return LogprobsTensors(indices, logprobs, token_ranks) + def _combine_outputs_with_spec_tokens( + self, + output_token_ids: list[list[int]], + spec_token_ids: Optional[list[list[int]]] = None, + ) -> list[list[int]]: + if spec_token_ids is None: + return output_token_ids + + return [ + [*out, *spec] if spec else out + for out, spec in zip(output_token_ids, spec_token_ids) + ] + + def apply_logits_processors( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + predict_bonus_token: bool, + ) -> torch.Tensor: + any_penalties_or_bad_words = ( + sampling_metadata.bad_words_token_ids or not sampling_metadata.no_penalties + ) + + output_token_ids = sampling_metadata.output_token_ids + if predict_bonus_token and any_penalties_or_bad_words: + # Combine base outputs with spec tokens when speculative decoding + # is enabled. + output_token_ids = self._combine_outputs_with_spec_tokens( + sampling_metadata.output_token_ids, + sampling_metadata.spec_token_ids, + ) + + # Apply allowed token ids. + if sampling_metadata.allowed_token_ids_mask is not None: + logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf")) + + # Apply bad words exclusion. + if sampling_metadata.bad_words_token_ids: + apply_bad_words( + logits, + sampling_metadata.bad_words_token_ids, + output_token_ids + if output_token_ids is not None + else sampling_metadata.output_token_ids, + ) + + # Apply logits processors which can impact greedy sampling. + for processor in sampling_metadata.logitsprocs.non_argmax_invariant: + logits = processor.apply(logits) + + # Apply penalties (e.g., freq_penalties). + logits = self.apply_penalties(logits, sampling_metadata, output_token_ids) + return logits + def apply_penalties( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, + output_token_ids: Optional[list[list[int]]] = None, ) -> torch.Tensor: if not sampling_metadata.no_penalties: assert sampling_metadata.prompt_token_ids is not None @@ -258,28 +306,8 @@ class Sampler(nn.Module): sampling_metadata.presence_penalties, sampling_metadata.frequency_penalties, sampling_metadata.repetition_penalties, - sampling_metadata.output_token_ids, - ) - return logits - - def apply_allowed_token_ids( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: - if sampling_metadata.allowed_token_ids_mask is not None: - logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf")) - return logits - - def apply_bad_words( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: - if sampling_metadata.bad_words_token_ids: - apply_bad_words( - logits, - sampling_metadata.bad_words_token_ids, - sampling_metadata.output_token_ids, + output_token_ids + if output_token_ids is not None + else sampling_metadata.output_token_ids, ) return logits diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 06f935423662e..22f5c6f7e6839 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -240,6 +240,9 @@ class InputBatch: # data structure self.logitsprocs = logitsprocs or LogitsProcessors() + # Store last speculative tokens for sampler. + self.spec_token_ids: list[Optional[list[int]]] = [] + # This is updated each time the batch constituents change. 
self.sampling_metadata = self._make_sampling_metadata() @@ -292,9 +295,11 @@ class InputBatch: if req_index == len(self._req_ids): self._req_ids.append(req_id) self.req_output_token_ids.append(request.output_token_ids) + self.spec_token_ids.append([]) else: self._req_ids[req_index] = req_id self.req_output_token_ids[req_index] = request.output_token_ids + self.spec_token_ids[req_index] = [] self.req_id_to_index[req_id] = req_index @@ -443,6 +448,7 @@ class InputBatch: self.batch_update_builder.removed_append(req_index) self._req_ids[req_index] = None self.req_output_token_ids[req_index] = None + self.spec_token_ids[req_index] = None # LoRA lora_id = self.request_lora_mapping[req_index] @@ -486,6 +492,10 @@ class InputBatch: self.req_output_token_ids[i2], self.req_output_token_ids[i1], ) + self.spec_token_ids[i1], self.spec_token_ids[i2] = ( + self.spec_token_ids[i2], + self.spec_token_ids[i1], + ) assert old_id_i1 is not None and old_id_i2 is not None self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = ( self.req_id_to_index[old_id_i2], @@ -601,6 +611,7 @@ class InputBatch: # The batched states are empty. self._req_ids.clear() self.req_output_token_ids.clear() + self.spec_token_ids.clear() return # NOTE(woosuk): This function assumes that the empty_req_indices @@ -629,6 +640,10 @@ class InputBatch: self.req_output_token_ids[last_req_index] = None self.req_id_to_index[req_id] = empty_index + spec_token_ids = self.spec_token_ids[last_req_index] + self.spec_token_ids[empty_index] = spec_token_ids + self.spec_token_ids[last_req_index] = None + num_tokens = self.num_tokens[last_req_index] self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ last_req_index, :num_tokens @@ -700,6 +715,7 @@ class InputBatch: # Trim lists to the batch size. del self._req_ids[num_reqs:] del self.req_output_token_ids[num_reqs:] + del self.spec_token_ids[num_reqs:] def refresh_metadata(self): """Apply any batch updates to sampling metadata.""" @@ -784,6 +800,7 @@ class InputBatch: presence_penalties=self.presence_penalties[:num_reqs], repetition_penalties=self.repetition_penalties[:num_reqs], output_token_ids=cast(list[list[int]], self.req_output_token_ids), + spec_token_ids=cast(list[list[int]], self.spec_token_ids), no_penalties=self.no_penalties, allowed_token_ids_mask=allowed_token_ids_mask, bad_words_token_ids=self.bad_words_token_ids, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index bd799c06c0eb6..8d1940da566f9 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -783,6 +783,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ] = spec_token_ids # NOTE(woosuk): `num_tokens` here may include spec tokens. self.input_batch.num_tokens[req_index] += num_spec_tokens + self.input_batch.spec_token_ids[req_index] = spec_token_ids # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. 
@@ -2197,6 +2198,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): sampler_output = self.sampler( logits=bonus_logits, sampling_metadata=sampling_metadata, + predict_bonus_token=True, ) bonus_token_ids = sampler_output.sampled_token_ids @@ -3491,6 +3493,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): presence_penalties=dummy_tensors(0.1), repetition_penalties=dummy_tensors(0.1), output_token_ids=[[] for _ in range(num_reqs)], + spec_token_ids=[[] for _ in range(num_reqs)], allowed_token_ids_mask=None, bad_words_token_ids={}, logitsprocs=LogitsProcessors(),
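
[Reviewer note, not part of the patch] The core idea in RejectionSampler.apply_logits_processors above is to expand per-request sampling state into one row per draft token before penalties and masks are applied: per-request tensors are gathered with repeat_interleave over num_draft_tokens, and the token history visible to draft position k is the request's prior output tokens plus its first k draft tokens (see _combine_outputs_with_spec_tokens). The sketch below is a minimal standalone illustration of that expansion; the helper name expand_for_draft_tokens and the example values are hypothetical and only mirror the patch's logic.

import torch


def expand_for_draft_tokens(
    output_token_ids: list[list[int]],
    spec_token_ids: list[list[int]],
    frequency_penalties: torch.Tensor,  # shape [num_requests]
) -> tuple[list[list[int]], torch.Tensor]:
    # Token history seen at draft position k of a request: the request's
    # prior outputs plus its first k draft tokens.
    histories: list[list[int]] = []
    for out, spec in zip(output_token_ids, spec_token_ids):
        if not spec:
            continue  # requests without drafts contribute no target-logit rows
        histories.append(out)
        for k in range(len(spec) - 1):
            histories.append([*histories[-1], spec[k]])

    # Per-request tensors are repeated once per draft token of that request.
    num_drafts = torch.tensor([len(s) for s in spec_token_ids])
    repeat_indices = torch.arange(len(spec_token_ids)).repeat_interleave(num_drafts)
    return histories, frequency_penalties[repeat_indices]


# Request 0 has three draft tokens, request 1 has none.
histories, expanded = expand_for_draft_tokens(
    output_token_ids=[[2], [3]],
    spec_token_ids=[[7, 8, 9], []],
    frequency_penalties=torch.tensor([1.5, 0.7]),
)
print(histories)          # [[2], [2, 7], [2, 7, 8]]
print(expanded.tolist())  # [1.5, 1.5, 1.5]

Requests without draft tokens contribute no rows, matching the rejection sampler's target-logit layout; the bonus-token pass in Sampler.apply_logits_processors instead appends the full draft sequence to each request's outputs, since the bonus token is sampled after all drafts.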