[Sampler] Introduce logprobs mode for logging (#21398)

Signed-off-by: Lu Fang <lufang@fb.com>
Lu Fang 2025-07-23 01:39:25 -07:00 committed by GitHub
parent 23637dcdef
commit accac82928
7 changed files with 83 additions and 13 deletions
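
In short, this change adds a `logprobs_mode` option that controls whether the values returned in `logprobs`/`prompt_logprobs` are log-probabilities or raw logits, taken either before ("raw_*") or after ("processed_*") the logit processors. A minimal usage sketch, mirroring the test added below (the model, prompt, and chosen mode are just examples):

# Minimal sketch mirroring the new test; model, prompt and mode are examples.
from vllm import LLM, SamplingParams

llm = LLM("facebook/opt-125m", max_logprobs=5, logprobs_mode="processed_logits")
outputs = llm.generate(["Hello world"], SamplingParams(logprobs=1))

for completion in outputs[0].outputs:
    for token_logprobs in completion.logprobs:
        for token_id, entry in token_logprobs.items():
            # With a *_logits mode, entry.logprob holds an unnormalized logit
            # (may be positive); with a *_logprobs mode it is always <= 0.
            print(token_id, entry.logprob)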

View File

@@ -12,6 +12,7 @@ from tests.v1.sample.utils import (
    assert_incr_detok_str_matches_non_incr_detok_str,
    compute_correct_cumulative_logprob, get_test_batch)
from vllm import SamplingParams
from vllm.config import LogprobsMode
from ...conftest import HfRunner, VllmRunner
@@ -426,3 +427,45 @@ def test_zero_logprobs(vllm_model, example_prompts,
        # prompt token
        assert prompt_logprobs is not None
        assert len(prompt_token_ids) == len(prompt_logprobs)


@pytest.mark.parametrize(
    "logprobs_mode",
    ["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"])
def test_logprobs_mode(logprobs_mode: LogprobsMode,
                       monkeypatch: pytest.MonkeyPatch):
    """Test the LLM engine with different logprobs_mode settings.

    For logprobs modes, all returned values should be non-positive.
    For logits modes, we expect at least one positive value.
    """
    from vllm import LLM
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        llm = LLM(
            "facebook/opt-125m",
            max_logprobs=5,
            enable_prefix_caching=False,
            # 2 other LLMs are alive during the whole test session.
            gpu_memory_utilization=0.05,
            max_model_len=16,
            logprobs_mode=logprobs_mode)
        vllm_sampling_params = SamplingParams(logprobs=1)
        results = llm.generate(["Hello world"],
                               sampling_params=vllm_sampling_params)

        total_token_with_logprobs = 0
        positive_values = 0
        for output in results[0].outputs:
            for logprobs in output.logprobs:
                for token_id in logprobs:
                    logprob = logprobs[token_id]
                    if "logprobs" in logprobs_mode:
                        assert logprob.logprob <= 0
                    if logprob.logprob > 0:
                        positive_values += 1
                    total_token_with_logprobs += 1
        assert total_token_with_logprobs >= len(results[0].outputs)
        if "logits" in logprobs_mode:
            assert positive_values > 0
        del llm

View File

@@ -219,6 +219,8 @@ def is_init_field(cls: ConfigType, name: str) -> bool:
TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
LogprobsMode = Literal["raw_logprobs", "raw_logits", "processed_logprobs",
                       "processed_logits"]


@config
@@ -316,6 +318,13 @@ class ModelConfig:
    """Maximum number of log probabilities to return when `logprobs` is
    specified in `SamplingParams`. The default value comes from the default
    for the OpenAI Chat Completions API."""
    logprobs_mode: LogprobsMode = "raw_logprobs"
    """Indicates the content returned in the logprobs and prompt_logprobs.
    Supported modes:
    1) raw_logprobs, 2) processed_logprobs, 3) raw_logits, 4) processed_logits.
    Raw means the values before applying any logit processors, such as bad
    words. Processed means the values after applying those processors.
    """
    disable_sliding_window: bool = False
    """Whether to disable sliding window. If True, we will disable the sliding
    window functionality of the model, capping to sliding window size. If the
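
To make the raw/processed distinction above concrete, a rough illustration follows; the tensors and the temperature-style processor here are made up for the sketch and are not vLLM internals:

# Illustration only; not vLLM internals.
import torch

raw_logits = torch.randn(1, 32000)            # model output before any logit processor
processed_logits = raw_logits / 0.7           # e.g. after a temperature-style processor

raw_logprobs = raw_logits.log_softmax(dim=-1)              # "raw_logprobs"
processed_logprobs = processed_logits.log_softmax(dim=-1)  # "processed_logprobs"
# "raw_logits" and "processed_logits" return the corresponding tensors as-is.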

View File

@@ -26,13 +26,13 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                         DetailedTraceModules, Device, DeviceConfig,
                         DistributedExecutorBackend, GuidedDecodingBackend,
                         GuidedDecodingBackendV1, HfOverrides, KVEventsConfig,
                         KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
                         ModelConfig, ModelDType, ModelImpl, MultiModalConfig,
                         ObservabilityConfig, ParallelConfig, PoolerConfig,
                         PrefixCachingHashAlgo, PromptAdapterConfig,
                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
                         TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
                         get_field)
                         KVTransferConfig, LoadConfig, LoadFormat,
                         LogprobsMode, LoRAConfig, ModelConfig, ModelDType,
                         ModelImpl, MultiModalConfig, ObservabilityConfig,
                         ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
                         PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
                         SpeculativeConfig, TaskOption, TokenizerMode,
                         VllmConfig, get_attr_docs, get_field)
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
@@ -324,6 +324,7 @@ class EngineArgs:
        SchedulerConfig.long_prefill_token_threshold
    max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
    max_logprobs: int = ModelConfig.max_logprobs
    logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
    disable_log_stats: bool = False
    revision: Optional[str] = ModelConfig.revision
    code_revision: Optional[str] = ModelConfig.code_revision
@@ -490,6 +491,8 @@
                                 **model_kwargs["max_seq_len_to_capture"])
        model_group.add_argument("--max-logprobs",
                                 **model_kwargs["max_logprobs"])
        model_group.add_argument("--logprobs-mode",
                                 **model_kwargs["logprobs_mode"])
        model_group.add_argument("--disable-sliding-window",
                                 **model_kwargs["disable_sliding_window"])
        model_group.add_argument("--disable-cascade-attn",
@@ -892,6 +895,7 @@
            enforce_eager=self.enforce_eager,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            logprobs_mode=self.logprobs_mode,
            disable_sliding_window=self.disable_sliding_window,
            disable_cascade_attn=self.disable_cascade_attn,
            skip_tokenizer_init=self.skip_tokenizer_init,
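
The same option is exposed on EngineArgs (and as --logprobs-mode on the CLI, via the model argument group above), from where it is forwarded into ModelConfig in the hunk just shown. A small sketch, assuming the usual EngineArgs constructor and field names:

# Sketch: equivalent of passing --logprobs-mode on the command line.
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(model="facebook/opt-125m",
                         logprobs_mode="processed_logprobs")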

View File

@@ -5,6 +5,7 @@
import torch
import torch.nn as nn
from vllm.config import LogprobsMode
from vllm.utils import is_pin_memory_available
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
@@ -18,10 +19,11 @@ _SAMPLING_EPS = 1e-5


class Sampler(nn.Module):

    def __init__(self):
    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
        super().__init__()
        self.topk_topp_sampler = TopKTopPSampler()
        self.pin_memory = is_pin_memory_available()
        self.logprobs_mode = logprobs_mode

    def forward(
        self,
@@ -36,7 +38,10 @@ class Sampler(nn.Module):
        # See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501
        num_logprobs = sampling_metadata.max_num_logprobs
        if num_logprobs is not None:
            if self.logprobs_mode == "raw_logprobs":
                raw_logprobs = self.compute_logprobs(logits)
            elif self.logprobs_mode == "raw_logits":
                raw_logprobs = logits.clone()

        # Use float32 for the logits.
        logits = logits.to(torch.float32)
@@ -51,6 +56,14 @@ class Sampler(nn.Module):
        # Apply penalties (e.g., min_tokens, freq_penalties).
        logits = self.apply_penalties(logits, sampling_metadata)

        # Get the processed logprobs or logits.
        if num_logprobs is not None:
            if self.logprobs_mode == "processed_logprobs":
                raw_logprobs = self.compute_logprobs(logits)
            elif self.logprobs_mode == "processed_logits":
                raw_logprobs = logits.clone()

        # Sample the next token.
        sampled = self.sample(logits, sampling_metadata)

        # Convert sampled token ids to int64 (long) type to ensure compatibility
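
The two branches above can be summarized by the following standalone helper (illustrative only, not part of the vLLM API): raw_* modes capture the values before penalties and other processors run, processed_* modes capture them afterwards, and the *_logprobs variants additionally apply a log-softmax, as Sampler.compute_logprobs does.

# Illustrative helper, not vLLM API: which tensor backs the returned values.
from typing import Literal

import torch

LogprobsMode = Literal["raw_logprobs", "raw_logits",
                       "processed_logprobs", "processed_logits"]


def pick_logprobs(mode: LogprobsMode, raw_logits: torch.Tensor,
                  processed_logits: torch.Tensor) -> torch.Tensor:
    source = raw_logits if mode.startswith("raw") else processed_logits
    if mode.endswith("logprobs"):
        # Log-softmax over the vocab in float32, mirroring compute_logprobs.
        return source.log_softmax(dim=-1, dtype=torch.float32)
    return source.clone()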

View File

@@ -15,6 +15,7 @@ _SAMPLING_EPS = 1e-5


class Sampler(nn.Module):

    def __init__(self):
        # TODO(houseroad): Add support for logprobs_mode.
        super().__init__()
        self.topk_topp_sampler = TopKTopPSampler()

View File

@@ -151,7 +151,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        self.encoder_cache_size = encoder_cache_size

        # Sampler
        self.sampler = Sampler()
        self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)

        self.eplb_state: Optional[EplbState] = None
"""