[Sampler] Support returning final logprobs (#22387)

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
22quinn 2025-08-20 21:28:32 -07:00 committed by GitHub
parent f64ee61d9e
commit f571ff8eb6
7 changed files with 125 additions and 69 deletions

View File

@ -154,12 +154,15 @@ differences compared to V0:
##### Logprobs Calculation
Logprobs in V1 are now returned immediately once computed from the model's raw output (i.e.
By default, logprobs in V1 are now returned immediately once computed from the model's raw output (i.e.
before applying any logits post-processing such as temperature scaling or penalty
adjustments). As a result, the returned logprobs do not reflect the final adjusted
probabilities used during sampling.
Support for logprobs with post-sampling adjustments is in progress and will be added in future updates.
You can adjust this behavior by setting the `--logprobs-mode` flag.
Four modes are supported: `raw_logprobs` (default), `processed_logprobs`, `raw_logits`, `processed_logits`.
Raw means the values before applying any logit processors, like bad words.
Processed means the values after applying all processors, including temperature and top_k/top_p.
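As a minimal illustration (a sketch, not part of this diff: the model name and prompt are placeholders, and it assumes the Python entrypoint accepts the mode as a plain string, mirroring the `--logprobs-mode` server flag):

```python
from vllm import LLM, SamplingParams

# Ask for logprobs computed *after* temperature/top-p processing.
# "facebook/opt-125m" is only a placeholder model for this sketch.
llm = LLM(model="facebook/opt-125m", logprobs_mode="processed_logprobs")
params = SamplingParams(temperature=0.8, top_p=0.9, logprobs=5)

outputs = llm.generate(["The capital of France is"], params)
for completion in outputs[0].outputs:
    # Each entry maps token ids to their (processed) logprob objects.
    print(completion.logprobs)
```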
##### Prompt Logprobs with Prefix Caching

View File

@ -456,9 +456,7 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
assert len(logprob) == vocab_size
@pytest.mark.parametrize(
"logprobs_mode",
["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"])
@pytest.mark.parametrize("logprobs_mode", list(LogprobsMode))
def test_logprobs_mode(logprobs_mode: LogprobsMode,
monkeypatch: pytest.MonkeyPatch):
"""Test with LLM engine with different logprobs_mode.
@ -487,12 +485,14 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode,
for logprobs in output.logprobs:
for token_id in logprobs:
logprob = logprobs[token_id]
if "logprobs" in logprobs_mode:
if logprobs_mode in (LogprobsMode.RAW_LOGPROBS,
LogprobsMode.PROCESSED_LOGPROBS):
assert logprob.logprob <= 0
if logprob.logprob > 0:
positive_values = positive_values + 1
total_token_with_logprobs = total_token_with_logprobs + 1
assert total_token_with_logprobs >= len(results[0].outputs)
if "logits" in logprobs_mode:
if logprobs_mode in (LogprobsMode.RAW_LOGITS,
LogprobsMode.PROCESSED_LOGITS):
assert positive_values > 0
del llm

View File

@ -257,11 +257,16 @@ def is_init_field(cls: ConfigType, name: str) -> bool:
TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
LogprobsMode = Literal["raw_logprobs", "raw_logits", "processed_logprobs",
"processed_logits"]
MMEncoderTPMode = Literal["weights", "data"]
class LogprobsMode(enum.Enum):
RAW_LOGITS = "raw_logits"
RAW_LOGPROBS = "raw_logprobs"
PROCESSED_LOGITS = "processed_logits"
PROCESSED_LOGPROBS = "processed_logprobs"
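A quick illustrative check of the enum (assuming it is importable from `vllm.config`, as in the sampler diff below): its string values round-trip through the enum, which is what lets the CLI expose them as `choices` for `--logprobs-mode`.

```python
from vllm.config import LogprobsMode

# String values and enum members are interchangeable representations.
assert LogprobsMode("processed_logprobs") is LogprobsMode.PROCESSED_LOGPROBS
assert LogprobsMode.RAW_LOGPROBS.value == "raw_logprobs"

# Definition order: raw_logits, raw_logprobs, processed_logits, processed_logprobs.
print([m.value for m in LogprobsMode])
```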
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class ModelConfig:
@ -363,12 +368,13 @@ class ModelConfig:
specified in `SamplingParams`. The default value comes from the default for the
OpenAI Chat Completions API. -1 means no cap, i.e. all (output_length *
vocab_size) logprobs are allowed to be returned and it may cause OOM."""
logprobs_mode: LogprobsMode = "raw_logprobs"
logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS
"""Indicates the content returned in the logprobs and prompt_logprobs.
Supported modes:
1) raw_logprobs, 2) processed_logprobs, 3) raw_logits, 4) processed_logits.
Raw means the values before applying logit processors, like bad words.
Processed means the values after applying such processors.
Raw means the values before applying any logit processors, like bad words.
Processed means the values after applying all processors, including
temperature and top_k/top_p.
"""
disable_sliding_window: bool = False
"""Whether to disable sliding window. If True, we will disable the sliding
@ -2586,7 +2592,7 @@ class MultiModalConfig:
skip_mm_profiling: bool = False
"""
When enabled, skips multimodal memory profiling and only profiles with
language backbone model during engine initialization.
This reduces engine startup time but shifts the responsibility to users for
@ -2649,24 +2655,24 @@ class PoolerConfig:
## for embeddings models
normalize: Optional[bool] = None
"""
Whether to normalize the embeddings outputs.
"""
dimensions: Optional[int] = None
"""
Reduce the dimensions of embeddings if the model
supports matryoshka representation.
"""
## for classification models
activation: Optional[bool] = None
"""
Whether to apply activation function to the classification outputs.
"""
## for reward models
softmax: Optional[bool] = None
"""
Whether to apply softmax to the reward outputs.
"""
step_tag_id: Optional[int] = None
"""
@ -2692,9 +2698,9 @@ class PoolerConfig:
max_embed_len: Optional[int] = None
"""
Maximum input length allowed for embedding generation. When set, allows
inputs longer than max_embed_len to be accepted for embedding models.
This parameter enables accepting long inputs without requiring
VLLM_ALLOW_LONG_MAX_MODEL_LEN environment variable. When an input exceeds
max_embed_len, it will be handled according to the original max_model_len
validation logic. Defaults to None (i.e. set to max_model_len).

View File

@ -516,6 +516,7 @@ class EngineArgs:
model_group.add_argument("--max-logprobs",
**model_kwargs["max_logprobs"])
model_group.add_argument("--logprobs-mode",
choices=[f.value for f in LogprobsMode],
**model_kwargs["logprobs_mode"])
model_group.add_argument("--disable-sliding-window",
**model_kwargs["disable_sliding_window"])

View File

@ -8,6 +8,7 @@ import torch.nn as nn
from packaging import version
from vllm import envs
from vllm.config import LogprobsMode
from vllm.logger import init_logger
from vllm.platforms import current_platform
@ -28,9 +29,16 @@ class TopKTopPSampler(nn.Module):
Implementations may update the logits tensor in-place.
"""
def __init__(self):
def __init__(
self,
logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS) -> None:
super().__init__()
if current_platform.is_cuda():
self.logprobs_mode = logprobs_mode
# flashinfer optimization does not apply if intermediate
# logprobs/logits after top_k/top_p need to be returned
if logprobs_mode not in (LogprobsMode.PROCESSED_LOGITS,
LogprobsMode.PROCESSED_LOGPROBS
) and current_platform.is_cuda():
if is_flashinfer_available:
flashinfer_version = flashinfer.__version__
if version.parse(flashinfer_version) < version.parse("0.2.3"):
@ -63,10 +71,12 @@ class TopKTopPSampler(nn.Module):
"native implementation of top-p & top-k sampling. For the "
"best performance, please install FlashInfer.")
self.forward = self.forward_native
elif current_platform.is_tpu():
self.forward = self.forward_tpu
else:
self.forward = self.forward_native
if current_platform.is_tpu():
self.apply_top_k_top_p = apply_top_k_top_p_tpu
else:
self.apply_top_k_top_p = apply_top_k_top_p
def forward_native(
self,
@ -74,15 +84,20 @@ class TopKTopPSampler(nn.Module):
generators: dict[int, torch.Generator],
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
PyTorch-native implementation of top-k and top-p sampling.
The logits tensor may be updated in-place.
"""
logits = apply_top_k_top_p(logits, k, p)
logits = self.apply_top_k_top_p(logits, k, p)
logits_to_return = None
if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS:
logits_to_return = logits
elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS:
logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators)
return random_sample(probs, generators), logits_to_return
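A standalone PyTorch sketch of the tuple contract above (illustrative only, simplified names and tiny random logits; not vLLM's actual class): the second element is `None` unless a processed mode was requested.

```python
import torch

def topk_topp_native_sketch(logits: torch.Tensor, mode: str):
    # Assume top-k/top-p masking has already been applied to `logits`.
    if mode == "processed_logits":
        extra = logits
    elif mode == "processed_logprobs":
        extra = logits.log_softmax(dim=-1, dtype=torch.float32)
    else:  # raw_* modes: the Sampler already captured logprobs earlier
        extra = None
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
    return sampled, extra

tokens, extra = topk_topp_native_sketch(torch.randn(2, 8), "processed_logprobs")
```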
def forward_cuda(
self,
@ -90,34 +105,24 @@ class TopKTopPSampler(nn.Module):
generators: dict[int, torch.Generator],
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""More optimized implementation for top-k and top-p sampling."""
if k is None and p is None:
# We prefer `random_sample` over `flashinfer_sample` when sorting is
# not needed. This is because `random_sample` does not require
# CPU-GPU synchronization while `flashinfer_sample` does.
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators)
if generators:
logger.warning_once("FlashInfer 0.2.3+ does not support "
"per-request generators. Falling back to "
"PyTorch-native implementation.")
# We prefer `random_sample` over `flashinfer_sample` when sorting is
# not needed. This is because `random_sample` does not require
# CPU-GPU synchronization while `flashinfer_sample` does.
if (k is None and p is None) or generators:
if generators:
logger.warning_once("FlashInfer 0.2.3+ does not support "
"per-request generators. Falling back to "
"PyTorch-native implementation.")
return self.forward_native(logits, generators, k, p)
assert self.logprobs_mode not in (
LogprobsMode.PROCESSED_LOGITS, LogprobsMode.PROCESSED_LOGPROBS
), "FlashInfer does not support returning logits/logprobs"
# flashinfer sampling functions expect contiguous logits.
# In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
# because of slicing operation in logits_processor.
return flashinfer_sample(logits.contiguous(), k, p, generators)
def forward_tpu(
self,
logits: torch.Tensor,
generators: dict[int, torch.Generator],
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
) -> torch.Tensor:
logits = apply_top_k_top_p_tpu(logits, k, p)
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators)
return flashinfer_sample(logits.contiguous(), k, p, generators), None
def apply_top_k_top_p_tpu(

View File

@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A layer that samples the next tokens from the model's outputs."""
from typing import Optional
import torch
import torch.nn as nn
@ -18,10 +20,50 @@ _SAMPLING_EPS = 1e-5
class Sampler(nn.Module):
"""
A layer that samples the next tokens from the model's outputs
with the following steps in order:
def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
1. If logprobs are requested:
a) If `logprobs_mode` is `raw_logprobs`, compute logprobs
as the final logprobs to return.
b) If `logprobs_mode` is `raw_logits`, clone the logits
as the final logprobs to return.
2. Convert logits to float32.
3. Apply allowed token ids whitelist.
4. Apply bad words exclusion.
5. Apply logit processors which are not argmax-invariant,
i.e. that can impact greedy sampling.
a) Min tokens processor
b) Logit bias processor
6. Apply penalties
a) Repetition penalty
b) Frequency penalty
c) Presence penalty
7. Sample the next tokens. `sample` method performs the following steps:
a) If not `all_random`, perform greedy sampling. If `all_greedy`,
return the greedily sampled tokens and final logprobs if requested.
b) Apply temperature.
c) Apply logit processors which are argmax-invariant, by default
the min_p processor.
d) Apply top_k and/or top_p.
e) Sample the next tokens with the probability distribution.
f) If `all_random` or temperature >= epsilon (1e-5), return the
randomly sampled tokens and final logprobs if requested. Else,
return the greedily sampled tokens and logprobs if requested.
8. Gather the logprobs of the top `max_num_logprobs` and sampled token
(if requested). Note that if the sampled token is within the top
`max_num_logprobs`, the logprob will be eventually merged in
`LogprobsProcessor` during output processing. Therefore, the
final output may contain either `max_num_logprobs + 1` or
`max_num_logprobs` logprobs.
9. Return the final `SamplerOutput`.
"""
def __init__(self,
logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS):
super().__init__()
self.topk_topp_sampler = TopKTopPSampler()
self.topk_topp_sampler = TopKTopPSampler(logprobs_mode)
self.pin_memory = is_pin_memory_available()
self.logprobs_mode = logprobs_mode
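The numbered steps above reduce to two capture points for the returned logprobs: raw modes are captured before any processing, processed modes only inside `sample` after all processing. A condensed, illustrative-only sketch (single temperature scalar, no penalties or greedy path; not the actual `forward`):

```python
import torch

def sampler_forward_sketch(logits: torch.Tensor, mode: str, temperature: float):
    to_return = None
    # Step 1: raw modes are captured from the unprocessed model output.
    if mode == "raw_logprobs":
        to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
    elif mode == "raw_logits":
        to_return = logits.clone()
    # Steps 2-7 (simplified): float32 conversion and temperature scaling.
    logits = logits.to(torch.float32) / temperature
    # Processed modes are captured only after processing, inside `sample`.
    if mode == "processed_logits":
        to_return = logits
    elif mode == "processed_logprobs":
        to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
    sampled = torch.multinomial(logits.softmax(dim=-1), num_samples=1).squeeze(-1)
    return sampled, to_return
```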
@ -34,13 +76,11 @@ class Sampler(nn.Module):
# temperature scaling) for the top-k logprobs.
# This is different from the V0 sampler, which uses the logits that
# are used for sampling (after penalties and temperature scaling).
# TODO(rob): provide option for logprobs post sampling.
# See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501
num_logprobs = sampling_metadata.max_num_logprobs
if num_logprobs is not None:
if self.logprobs_mode == "raw_logprobs":
if self.logprobs_mode == LogprobsMode.RAW_LOGPROBS:
raw_logprobs = self.compute_logprobs(logits)
elif self.logprobs_mode == "raw_logits":
elif self.logprobs_mode == LogprobsMode.RAW_LOGITS:
raw_logprobs = logits.clone()
# Use float32 for the logits.
@ -57,15 +97,10 @@ class Sampler(nn.Module):
# Apply penalties (e.g., min_tokens, freq_penalties).
logits = self.apply_penalties(logits, sampling_metadata)
# Get the process logprobs or logits.
if num_logprobs is not None:
if self.logprobs_mode == "processed_logprobs":
raw_logprobs = self.compute_logprobs(logits)
elif self.logprobs_mode == "processed_logits":
raw_logprobs = logits.clone()
# Sample the next token.
sampled = self.sample(logits, sampling_metadata)
sampled, processed_logprobs = self.sample(logits, sampling_metadata)
if processed_logprobs is not None:
raw_logprobs = processed_logprobs
# Convert sampled token ids to int64 (long) type to ensure compatibility
# with subsequent operations that may use these values as indices.
# This conversion is necessary because FlashInfer sampling operations
@ -105,7 +140,7 @@ class Sampler(nn.Module):
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Sample logits based on sampling metadata.
The various logits processing functions called in this method
@ -119,7 +154,13 @@ class Sampler(nn.Module):
else:
greedy_sampled = self.greedy_sample(logits)
if sampling_metadata.all_greedy:
return greedy_sampled
processed_logprobs = None
if sampling_metadata.max_num_logprobs is not None:
if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS:
processed_logprobs = logits
elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS:
processed_logprobs = self.compute_logprobs(logits)
return greedy_sampled, processed_logprobs
assert sampling_metadata.temperature is not None
@ -132,7 +173,7 @@ class Sampler(nn.Module):
logits = processor.apply(logits)
# Apply top_k and/or top_p.
random_sampled = self.topk_topp_sampler(
random_sampled, processed_logprobs = self.topk_topp_sampler(
logits,
sampling_metadata.generators,
sampling_metadata.top_k,
@ -140,7 +181,7 @@ class Sampler(nn.Module):
)
if greedy_sampled is None:
return random_sampled
return random_sampled, processed_logprobs
sampled = torch.where(
sampling_metadata.temperature < _SAMPLING_EPS,
@ -148,7 +189,7 @@ class Sampler(nn.Module):
random_sampled,
out=greedy_sampled, # Reuse tensor
)
return sampled
return sampled, processed_logprobs
def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
return logits.log_softmax(dim=-1, dtype=torch.float32)

View File

@ -65,7 +65,7 @@ class Sampler(nn.Module):
logits = self.apply_min_p(logits, sampling_metadata.min_p)
# Apply top_k and/or top_p.
random_sampled = self.topk_topp_sampler(
random_sampled, _ = self.topk_topp_sampler(
logits,
sampling_metadata.generators,
sampling_metadata.top_k,