[V1] Logit processors for rejection sampler (#19482)

Signed-off-by: southfreebird <yvorott@gmail.com>
Signed-off-by: Sergei Skvortsov <sergeyskv@nebius.com>
Signed-off-by: Sergei Skvortsov <yvorott@gmail.com>
Co-authored-by: Sergei Skvortsov <sergeyskv@nebius.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Sergei Skvortsov 2025-10-07 21:02:49 +01:00 committed by GitHub
parent 0c824fc46f
commit 6ebaf43ee4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 471 additions and 92 deletions

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import sys
from typing import Union
from typing import Any, Union
import pytest
@ -25,6 +25,7 @@ from tests.v1.logits_processors.utils import entry_points as fake_entry_points
from vllm import LLM, SamplingParams
from vllm.v1.sample.logits_processor import (
STR_POOLING_REJECTS_LOGITSPROCS,
STR_SPEC_DEC_REJECTS_LOGITSPROCS,
LogitsProcessor,
)
@ -205,6 +206,7 @@ def test_custom_logitsprocs_req(monkeypatch):
@create_new_process_for_each_test()
@pytest.mark.parametrize("model_scenario", ["pooling", "spec_dec"])
@pytest.mark.parametrize(
"logitproc_source",
[
@ -213,11 +215,12 @@ def test_custom_logitsprocs_req(monkeypatch):
CustomLogitprocSource.LOGITPROC_SOURCE_CLASS,
],
)
def test_pooling_rejects_custom_logitsprocs(
monkeypatch, logitproc_source: CustomLogitprocSource
def test_rejects_custom_logitsprocs(
monkeypatch, model_scenario: str, logitproc_source: CustomLogitprocSource
):
"""Validate that vLLM engine initialization properly rejects custom
logitsprocs when the model is a pooling model.
logitsprocs when the model is a pooling model or speculative decoding
enabled.
Use `LLM` entrypoint. We expect `LLM` initialization to fail before the
logitproc is actually loaded.
@ -241,8 +244,32 @@ def test_pooling_rejects_custom_logitsprocs(
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
random.seed(40)
test_params: dict[str, dict[str, Any]] = {
"pooling": {
"runner": "pooling",
"model": POOLING_MODEL_NAME,
"error_message": STR_POOLING_REJECTS_LOGITSPROCS,
"speculative_config": None,
},
"spec_dec": {
"runner": "auto",
"model": MODEL_NAME,
"error_message": STR_SPEC_DEC_REJECTS_LOGITSPROCS,
"speculative_config": {"model": "ngram", "num_speculative_tokens": 1},
},
}
config = test_params[model_scenario]
llm_kwargs: dict[str, Any] = {
"runner": config["runner"],
"model": config["model"],
"gpu_memory_utilization": 0.1,
"speculative_config": config["speculative_config"],
}
if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT:
# Scenario: vLLM loads a pooling model and ignores a logitproc that is
# Scenario: vLLM loads a model and ignores a logitproc that is
# available at a preconfigured entrypoint
# Patch in dummy logitproc entrypoint
@ -254,30 +281,20 @@ def test_pooling_rejects_custom_logitsprocs(
# although they should ignore the entrypoint patch anyway
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "fork")
llm = LLM(
runner="pooling",
model=POOLING_MODEL_NAME,
gpu_memory_utilization=0.1,
)
llm = LLM(**llm_kwargs)
# Require that no logitsprocs have been loaded
worker = llm.llm_engine.model_executor.driver_worker.worker
assert sum([1 for _ in worker.model_runner.input_batch.logitsprocs.all]) == 0
return
kwargs: dict[str, list[Union[str, type[LogitsProcessor]]]] = {}
if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN:
# Scenario: load logitproc based on fully-qualified class name (FQCN)
kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN]
llm_kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN]
elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS:
# Scenario: load logitproc from provided class object
kwargs["logits_processors"] = [DummyLogitsProcessor]
llm_kwargs["logits_processors"] = [DummyLogitsProcessor]
with pytest.raises(ValueError, match=STR_POOLING_REJECTS_LOGITSPROCS):
# Require that loading a pooling model alongside the logitproc raises
with pytest.raises(ValueError, match=config["error_message"]):
# Require that loading a model alongside the logitproc raises
# the appropriate exception.
LLM(
runner="pooling",
model=POOLING_MODEL_NAME,
gpu_memory_utilization=0.1,
**kwargs,
)
LLM(**llm_kwargs)

View File

@ -6,6 +6,7 @@ import pytest
import torch
import torch.nn.functional as F
from tests.v1.sample.utils import create_allowed_token_ids
from vllm.platforms import current_platform
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
@ -21,7 +22,9 @@ def rejection_sampler():
def create_logits_tensor(
output_token_ids: list[list[int]], vocab_size: int = 100
output_token_ids: list[list[int]],
vocab_size: int = 100,
token_idx_to_override: Optional[int] = None,
) -> torch.Tensor:
"""Helper function to create logits tensor that
will produce desired token ids on argmax"""
@ -33,15 +36,25 @@ def create_logits_tensor(
for j, token_id in enumerate(tokens):
logits[start_loc + j, token_id] = 100.0
start_loc += len(tokens)
if token_idx_to_override:
logits[:, token_idx_to_override] = 99.0
return logits
def create_sampling_metadata(
all_greedy: bool,
output_token_ids: Optional[list[list[int]]] = None,
prompt_token_ids: Optional[torch.Tensor] = None,
spec_token_ids: Optional[torch.Tensor] = None,
temperature: Optional[torch.Tensor] = None,
top_k: Optional[torch.Tensor] = None,
top_p: Optional[torch.Tensor] = None,
generators: Optional[dict[int, Any]] = None,
frequency_penalties: Optional[list[float]] = None,
presence_penalties: Optional[list[float]] = None,
repetition_penalties: Optional[list[float]] = None,
bad_words_token_ids: Optional[dict[int, list[list[int]]]] = None,
allowed_token_ids_mask: Optional[torch.Tensor] = None,
) -> SamplingMetadata:
"""Create a v1 sampling metadata object with all_greedy set
to the given value. Either all greedy or all random sampling
@ -53,6 +66,21 @@ def create_sampling_metadata(
else:
assert temperature is not None
if any([frequency_penalties, presence_penalties, repetition_penalties]):
no_penalties = False
assert output_token_ids
assert len(output_token_ids) > 0
frequency_penalties = torch.tensor(frequency_penalties, device=DEVICE)
presence_penalties = torch.tensor(presence_penalties, device=DEVICE)
repetition_penalties = torch.tensor(repetition_penalties, device=DEVICE)
else:
no_penalties = True
frequency_penalties = torch.tensor([])
presence_penalties = torch.tensor([])
repetition_penalties = torch.tensor([])
return SamplingMetadata(
temperature=temperature,
all_greedy=all_greedy,
@ -61,14 +89,15 @@ def create_sampling_metadata(
top_k=top_k,
generators=generators,
max_num_logprobs=0,
no_penalties=False,
prompt_token_ids=None,
frequency_penalties=torch.tensor([]),
presence_penalties=torch.tensor([]),
repetition_penalties=torch.tensor([]),
output_token_ids=[],
allowed_token_ids_mask=None,
bad_words_token_ids={},
no_penalties=no_penalties,
prompt_token_ids=prompt_token_ids,
frequency_penalties=frequency_penalties,
presence_penalties=presence_penalties,
repetition_penalties=repetition_penalties,
output_token_ids=[] if output_token_ids is None else output_token_ids,
spec_token_ids=[] if spec_token_ids is None else spec_token_ids,
allowed_token_ids_mask=allowed_token_ids_mask,
bad_words_token_ids={} if bad_words_token_ids is None else bad_words_token_ids,
logitsprocs=LogitsProcessors(),
)
@ -611,3 +640,136 @@ def test_top_p(rejection_sampler, top_p):
unmasked_indices=top_p_indices,
sampling_metadata=sampling_metadata,
)
########################### Tests for Logit Processors ###################
def test_frequency_penalties(rejection_sampler):
"""Test rejection sampling with frequency penalties"""
spec_tokens = [[1, 1, 1], [], [1, 1, 1]]
output_tokens = [[1, 1, 1, 1], [7], [1, 1, 1, 1]] # 1, 7 and 1 are the bonus tokens
num_requsts = len(spec_tokens)
logits = create_logits_tensor(output_tokens, token_idx_to_override=15)
metadata = create_sampling_metadata(
all_greedy=True,
output_token_ids=[[2], [3], [4]],
spec_token_ids=spec_tokens,
prompt_token_ids=torch.tensor([[5, 6, 7], [6, 7, 8], [7, 8, 9]], device=DEVICE),
frequency_penalties=[1.5, 1.5, 0.7],
presence_penalties=[0.0] * num_requsts,
repetition_penalties=[1.0] * num_requsts,
)
bonus_token_tensor = torch.tensor(
[output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device
)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected = torch.tensor(
[[1, 15, -1, -1], [7, -1, -1, -1], [1, 1, 15, -1]],
dtype=torch.int,
device=logits.device,
)
assert torch.equal(output, expected)
def test_bad_words(rejection_sampler):
"""Test rejection sampling with bad words constraints"""
spec_tokens = [[1, 2, 3], [1, 15, 3], [1, 2, 3]]
output_tokens = [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]
logits = create_logits_tensor(output_tokens, token_idx_to_override=15)
metadata = create_sampling_metadata(
all_greedy=True,
output_token_ids=[[2], [3], [4]],
spec_token_ids=spec_tokens,
bad_words_token_ids={
0: [
[
2,
]
],
1: [
[
2,
]
],
# Do not apply bad words to the last request
},
)
bonus_token_tensor = torch.tensor(
[output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device
)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected = torch.tensor(
[[1, 15, -1, -1], [1, 15, 3, 4], [1, 2, 3, 4]],
dtype=torch.int,
device=logits.device,
)
assert torch.equal(output, expected)
def test_allowed_token_ids(rejection_sampler):
"""Test rejection sampling with allowed token ids"""
spec_tokens = [[1, 2, 10], [10, 5, 3], [7, 10, 12]]
output_tokens = [[1, 2, 10, 5], [10, 5, 10, 5], [7, 10, 12, 5]]
# Not allowed tokens:
# 0: 0-4
# 1: 1-5
# 2: 2-6
num_allowed_token_ids = 5
# Use the token 15 as the sampler choose if a token rejected
logits = create_logits_tensor(output_tokens, token_idx_to_override=15)
batch_size = len(output_tokens)
_, vocab_size = logits.size()
mask = create_allowed_token_ids(
batch_size=batch_size,
vocab_size=vocab_size,
num_allowed_token_ids=num_allowed_token_ids,
device=logits.device,
)
metadata = create_sampling_metadata(
all_greedy=True,
output_token_ids=[[], [], []],
spec_token_ids=spec_tokens,
allowed_token_ids_mask=mask,
)
bonus_token_tensor = torch.tensor(
[output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device
)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected = torch.tensor(
[[15, -1, -1, -1], [10, 5, 10, -1], [7, 10, 12, 5]],
dtype=torch.int,
device=logits.device,
)
assert torch.equal(output, expected)

View File

@ -1,12 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import numpy as np
import pytest
import torch
from tests.v1.sample.utils import create_allowed_token_ids
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.sample.logits_processor import LogitsProcessors
@ -51,26 +50,6 @@ def _create_prompt_tokens_tensor(
)
def _create_allowed_token_ids(
batch_size: int,
vocab_size: int,
num_allowed_token_ids: int,
device: torch.device,
) -> Optional[torch.Tensor]:
mask: Optional[torch.Tensor] = None
for i in range(batch_size):
if i % 2 == 1:
continue
if mask is None:
mask = torch.zeros(
(batch_size, vocab_size), dtype=torch.bool, device=device
)
start = min(i, vocab_size - 1)
end = min(i + num_allowed_token_ids, vocab_size - 1)
mask[i, start:end] = True
return mask
def _create_bad_words_token_ids(
batch_size: int,
vocab_size: int,
@ -173,6 +152,7 @@ def _create_default_sampling_metadata(
prompt_token_ids, vocab_size, device
),
output_token_ids=output_token_ids,
spec_token_ids=[[] for _ in range(batch_size)],
frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
@ -241,7 +221,9 @@ def test_sampler_presence_penalty(
)
sampling_metadata.no_penalties = False
sampler = Sampler()
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
logits = sampler.apply_penalties(
fake_logits, sampling_metadata, sampling_metadata.output_token_ids
)
logits = logits.cpu()
for batch_idx in range(batch_size):
# Since all tokens initially have the same logits, the non-penalized
@ -293,7 +275,9 @@ def test_sampler_frequency_penalty(
sampling_metadata.output_token_ids = output_token_ids
sampling_metadata.no_penalties = False
sampler = Sampler()
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
logits = sampler.apply_penalties(
fake_logits, sampling_metadata, sampling_metadata.output_token_ids
)
logits = logits.cpu()
for batch_idx in range(batch_size):
non_penalized_token_id = logits[batch_idx].argmax().item()
@ -343,7 +327,9 @@ def test_sampler_repetition_penalty(
)
sampling_metadata.no_penalties = False
sampler = Sampler()
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
logits = sampler.apply_penalties(
fake_logits, sampling_metadata, sampling_metadata.output_token_ids
)
logits = logits.cpu()
for batch_idx in range(batch_size):
non_penalized_token_id = logits[batch_idx].argmax().item()
@ -394,7 +380,7 @@ def test_sampler_allowed_token_ids(
sampling_metadata = _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)
)
mask = _create_allowed_token_ids(
mask = create_allowed_token_ids(
batch_size=batch_size,
vocab_size=VOCAB_SIZE,
num_allowed_token_ids=num_allowed_token_ids,
@ -402,7 +388,9 @@ def test_sampler_allowed_token_ids(
)
sampling_metadata.allowed_token_ids_mask = mask
sampler = Sampler()
logits = sampler.apply_allowed_token_ids(fake_logits, sampling_metadata)
logits = sampler.apply_logits_processors(
fake_logits, sampling_metadata, predict_bonus_token=False
)
logits = logits.cpu()
for batch_idx in range(batch_size):
logits_for_req = logits[batch_idx]
@ -444,7 +432,9 @@ def test_sampler_bad_words(
sampling_metadata, VOCAB_SIZE
)
sampler = Sampler()
logits = sampler.apply_bad_words(fake_logits, sampling_metadata)
logits = sampler.apply_logits_processors(
fake_logits, sampling_metadata, predict_bonus_token=False
)
logits = logits.cpu()
for batch_idx in range(batch_size):
logits_for_req = logits[batch_idx]

View File

@ -215,3 +215,23 @@ def fake_apply_logitsprocs(
for processor in test_fakes.get_logitsprocs():
logits = processor.apply(logits)
return logits
def create_allowed_token_ids(
batch_size: int,
vocab_size: int,
num_allowed_token_ids: int,
device: torch.device,
) -> Optional[torch.Tensor]:
mask: Optional[torch.Tensor] = None
for i in range(batch_size):
if i % 2 == 1:
continue
if mask is None:
mask = torch.zeros(
(batch_size, vocab_size), dtype=torch.bool, device=device
)
start = min(i, vocab_size - 1)
end = min(i + num_allowed_token_ids, vocab_size - 1)
mask[i, start:end] = True
return mask

View File

@ -170,6 +170,7 @@ def _construct_expected_sampling_metadata(
repetition_penalties, dtype=torch.float, device=device
),
output_token_ids=output_token_ids,
spec_token_ids=[[] for _ in range(len(output_token_ids))],
no_penalties=(
all(x == 0 for x in presence_penalties)
and all(x == 0 for x in frequency_penalties)

View File

@ -37,6 +37,12 @@ STR_POOLING_REJECTS_LOGITSPROCS = (
"Pooling models do not support custom logits processors."
)
# Error message when the user tries to initialize vLLM with a speculative
# decoding enabled and custom logitsproces
STR_SPEC_DEC_REJECTS_LOGITSPROCS = (
"Custom logits processors are not supportedwhen speculative decoding is enabled."
)
LOGITSPROCS_GROUP = "vllm.logits_processors"
BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [
@ -185,6 +191,17 @@ def build_logitsprocs(
" do not support logits processors."
)
return LogitsProcessors()
# Check if speculative decoding is enabled.
if vllm_config.speculative_config:
if custom_logitsprocs:
raise ValueError(STR_SPEC_DEC_REJECTS_LOGITSPROCS)
logger.warning(
"min_p, logit_bias, and min_tokens parameters won't currently work "
"with speculative decoding enabled."
)
return LogitsProcessors()
custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs)
return LogitsProcessors(
ctor(vllm_config, device, is_pin_memory)

View File

@ -40,3 +40,6 @@ class SamplingMetadata:
# Loaded logits processors
logitsprocs: LogitsProcessors
# Speculative token ids
spec_token_ids: Optional[list[list[int]]] = None

View File

@ -33,3 +33,20 @@ def apply_bad_words(
) -> None:
for i, bad_words_ids in bad_words_token_ids.items():
_apply_bad_words_single_batch(logits[i], bad_words_ids, past_tokens_ids[i])
def apply_bad_words_with_drafts(
logits: torch.Tensor,
bad_words_token_ids: dict[int, list[list[int]]],
past_tokens_ids: list[list[int]],
num_draft_tokens: list[int],
) -> None:
start_idx = 0
for i, bad_words_ids in bad_words_token_ids.items():
for draft_idx in range(num_draft_tokens[i]):
_apply_bad_words_single_batch(
logits[start_idx + draft_idx],
bad_words_ids,
past_tokens_ids[start_idx + draft_idx],
)
start_idx += num_draft_tokens[i]

View File

@ -8,6 +8,8 @@ import torch.nn as nn
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts
from vllm.v1.sample.ops.penalties import apply_all_penalties
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@ -83,6 +85,14 @@ class RejectionSampler(nn.Module):
A tensor containing the final output token IDs.
"""
assert metadata.max_spec_len <= MAX_SPEC_LEN
# Use float32 for the target_logits.
target_logits = target_logits.to(torch.float32)
target_logits = self.apply_logits_processors(
target_logits, sampling_metadata, metadata
)
# [num_tokens, vocab_size]
# NOTE(woosuk): `target_logits` can be updated in place inside the
# `compute_probs` function.
@ -131,6 +141,100 @@ class RejectionSampler(nn.Module):
]
return outputs
def apply_logits_processors(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
metadata: SpecDecodeMetadata,
) -> torch.Tensor:
any_penalties_or_bad_words = (
sampling_metadata.bad_words_token_ids or not sampling_metadata.no_penalties
)
output_token_ids = sampling_metadata.output_token_ids
if any_penalties_or_bad_words:
output_token_ids = self._combine_outputs_with_spec_tokens(
sampling_metadata.output_token_ids,
sampling_metadata.spec_token_ids,
)
# Calculate indices of target logits.
if (
sampling_metadata.allowed_token_ids_mask is not None
or not sampling_metadata.no_penalties
):
num_requests = len(sampling_metadata.output_token_ids)
num_draft_tokens = torch.tensor(metadata.num_draft_tokens, device="cpu")
original_indices = torch.arange(num_requests, device="cpu")
repeat_indices_cpu = original_indices.repeat_interleave(num_draft_tokens)
repeat_indices = repeat_indices_cpu.to(
device=logits.device, non_blocking=True
)
logits = self.apply_penalties(
logits, sampling_metadata, metadata, repeat_indices, output_token_ids
)
# Apply allowed token ids.
if sampling_metadata.allowed_token_ids_mask is not None:
token_mask = sampling_metadata.allowed_token_ids_mask[repeat_indices]
logits.masked_fill_(token_mask, float("-inf"))
# Apply bad words exclusion.
if sampling_metadata.bad_words_token_ids:
apply_bad_words_with_drafts(
logits,
sampling_metadata.bad_words_token_ids,
output_token_ids,
metadata.num_draft_tokens,
)
return logits
def apply_penalties(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
metadata: SpecDecodeMetadata,
repeat_indices: torch.Tensor,
output_token_ids: list[list[int]],
) -> torch.Tensor:
if sampling_metadata.no_penalties:
return logits
assert sampling_metadata.prompt_token_ids is not None
prompt_token_ids = sampling_metadata.prompt_token_ids[repeat_indices]
presence_penalties = sampling_metadata.presence_penalties[repeat_indices]
frequency_penalties = sampling_metadata.frequency_penalties[repeat_indices]
repetition_penalties = sampling_metadata.repetition_penalties[repeat_indices]
logits = apply_all_penalties(
logits,
prompt_token_ids,
presence_penalties,
frequency_penalties,
repetition_penalties,
output_token_ids,
)
return logits
def _combine_outputs_with_spec_tokens(
self,
output_token_ids: list[list[int]],
spec_token_ids: Optional[list[list[int]]] = None,
) -> list[list[int]]:
if spec_token_ids is None:
return output_token_ids
result = []
for out, spec in zip(output_token_ids, spec_token_ids):
if len(spec) == 0:
continue
result.append(out)
for i in range(len(spec) - 1):
result.append([*result[-1], spec[i]])
return result
def rejection_sample(
# [num_tokens]

View File

@ -70,6 +70,7 @@ class Sampler(nn.Module):
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
predict_bonus_token: bool = False,
) -> SamplerOutput:
# NOTE(woosuk): Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs.
@ -84,18 +85,10 @@ class Sampler(nn.Module):
# Use float32 for the logits.
logits = logits.to(torch.float32)
# Apply allowed token ids.
logits = self.apply_allowed_token_ids(logits, sampling_metadata)
# Apply bad words exclusion.
logits = self.apply_bad_words(logits, sampling_metadata)
# Apply logits processors which can impact greedy sampling
for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
logits = processor.apply(logits)
# Apply penalties (e.g., min_tokens, freq_penalties).
logits = self.apply_penalties(logits, sampling_metadata)
logits = self.apply_logits_processors(
logits, sampling_metadata, predict_bonus_token
)
# Sample the next token.
sampled, processed_logprobs = self.sample(logits, sampling_metadata)
if processed_logprobs is not None:
@ -245,10 +238,65 @@ class Sampler(nn.Module):
return LogprobsTensors(indices, logprobs, token_ranks)
def _combine_outputs_with_spec_tokens(
self,
output_token_ids: list[list[int]],
spec_token_ids: Optional[list[list[int]]] = None,
) -> list[list[int]]:
if spec_token_ids is None:
return output_token_ids
return [
[*out, *spec] if spec else out
for out, spec in zip(output_token_ids, spec_token_ids)
]
def apply_logits_processors(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
predict_bonus_token: bool,
) -> torch.Tensor:
any_penalties_or_bad_words = (
sampling_metadata.bad_words_token_ids or not sampling_metadata.no_penalties
)
output_token_ids = sampling_metadata.output_token_ids
if predict_bonus_token and any_penalties_or_bad_words:
# Combine base outputs with spec tokens when speculative decoding
# is enabled.
output_token_ids = self._combine_outputs_with_spec_tokens(
sampling_metadata.output_token_ids,
sampling_metadata.spec_token_ids,
)
# Apply allowed token ids.
if sampling_metadata.allowed_token_ids_mask is not None:
logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf"))
# Apply bad words exclusion.
if sampling_metadata.bad_words_token_ids:
apply_bad_words(
logits,
sampling_metadata.bad_words_token_ids,
output_token_ids
if output_token_ids is not None
else sampling_metadata.output_token_ids,
)
# Apply logits processors which can impact greedy sampling.
for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
logits = processor.apply(logits)
# Apply penalties (e.g., freq_penalties).
logits = self.apply_penalties(logits, sampling_metadata, output_token_ids)
return logits
def apply_penalties(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
output_token_ids: Optional[list[list[int]]] = None,
) -> torch.Tensor:
if not sampling_metadata.no_penalties:
assert sampling_metadata.prompt_token_ids is not None
@ -258,28 +306,8 @@ class Sampler(nn.Module):
sampling_metadata.presence_penalties,
sampling_metadata.frequency_penalties,
sampling_metadata.repetition_penalties,
sampling_metadata.output_token_ids,
)
return logits
def apply_allowed_token_ids(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
if sampling_metadata.allowed_token_ids_mask is not None:
logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf"))
return logits
def apply_bad_words(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
if sampling_metadata.bad_words_token_ids:
apply_bad_words(
logits,
sampling_metadata.bad_words_token_ids,
sampling_metadata.output_token_ids,
output_token_ids
if output_token_ids is not None
else sampling_metadata.output_token_ids,
)
return logits

View File

@ -240,6 +240,9 @@ class InputBatch:
# data structure
self.logitsprocs = logitsprocs or LogitsProcessors()
# Store last speculative tokens for sampler.
self.spec_token_ids: list[Optional[list[int]]] = []
# This is updated each time the batch constituents change.
self.sampling_metadata = self._make_sampling_metadata()
@ -292,9 +295,11 @@ class InputBatch:
if req_index == len(self._req_ids):
self._req_ids.append(req_id)
self.req_output_token_ids.append(request.output_token_ids)
self.spec_token_ids.append([])
else:
self._req_ids[req_index] = req_id
self.req_output_token_ids[req_index] = request.output_token_ids
self.spec_token_ids[req_index] = []
self.req_id_to_index[req_id] = req_index
@ -443,6 +448,7 @@ class InputBatch:
self.batch_update_builder.removed_append(req_index)
self._req_ids[req_index] = None
self.req_output_token_ids[req_index] = None
self.spec_token_ids[req_index] = None
# LoRA
lora_id = self.request_lora_mapping[req_index]
@ -486,6 +492,10 @@ class InputBatch:
self.req_output_token_ids[i2],
self.req_output_token_ids[i1],
)
self.spec_token_ids[i1], self.spec_token_ids[i2] = (
self.spec_token_ids[i2],
self.spec_token_ids[i1],
)
assert old_id_i1 is not None and old_id_i2 is not None
self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = (
self.req_id_to_index[old_id_i2],
@ -601,6 +611,7 @@ class InputBatch:
# The batched states are empty.
self._req_ids.clear()
self.req_output_token_ids.clear()
self.spec_token_ids.clear()
return
# NOTE(woosuk): This function assumes that the empty_req_indices
@ -629,6 +640,10 @@ class InputBatch:
self.req_output_token_ids[last_req_index] = None
self.req_id_to_index[req_id] = empty_index
spec_token_ids = self.spec_token_ids[last_req_index]
self.spec_token_ids[empty_index] = spec_token_ids
self.spec_token_ids[last_req_index] = None
num_tokens = self.num_tokens[last_req_index]
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
last_req_index, :num_tokens
@ -700,6 +715,7 @@ class InputBatch:
# Trim lists to the batch size.
del self._req_ids[num_reqs:]
del self.req_output_token_ids[num_reqs:]
del self.spec_token_ids[num_reqs:]
def refresh_metadata(self):
"""Apply any batch updates to sampling metadata."""
@ -784,6 +800,7 @@ class InputBatch:
presence_penalties=self.presence_penalties[:num_reqs],
repetition_penalties=self.repetition_penalties[:num_reqs],
output_token_ids=cast(list[list[int]], self.req_output_token_ids),
spec_token_ids=cast(list[list[int]], self.spec_token_ids),
no_penalties=self.no_penalties,
allowed_token_ids_mask=allowed_token_ids_mask,
bad_words_token_ids=self.bad_words_token_ids,

View File

@ -783,6 +783,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
] = spec_token_ids
# NOTE(woosuk): `num_tokens` here may include spec tokens.
self.input_batch.num_tokens[req_index] += num_spec_tokens
self.input_batch.spec_token_ids[req_index] = spec_token_ids
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
@ -2197,6 +2198,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
sampler_output = self.sampler(
logits=bonus_logits,
sampling_metadata=sampling_metadata,
predict_bonus_token=True,
)
bonus_token_ids = sampler_output.sampled_token_ids
@ -3491,6 +3493,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
presence_penalties=dummy_tensors(0.1),
repetition_penalties=dummy_tensors(0.1),
output_token_ids=[[] for _ in range(num_reqs)],
spec_token_ids=[[] for _ in range(num_reqs)],
allowed_token_ids_mask=None,
bad_words_token_ids={},
logitsprocs=LogitsProcessors(),