[Sampler] Introduce logprobs mode for logging (#21398)

Signed-off-by: Lu Fang <lufang@fb.com>
Lu Fang 2025-07-23 01:39:25 -07:00 committed by GitHub
parent 23637dcdef
commit accac82928
7 changed files with 83 additions and 13 deletions

View File

@@ -12,6 +12,7 @@ from tests.v1.sample.utils import (
    assert_incr_detok_str_matches_non_incr_detok_str,
    compute_correct_cumulative_logprob, get_test_batch)
from vllm import SamplingParams
from vllm.config import LogprobsMode

from ...conftest import HfRunner, VllmRunner
@@ -426,3 +427,45 @@ def test_zero_logprobs(vllm_model, example_prompts,
        # prompt token
        assert prompt_logprobs is not None
        assert len(prompt_token_ids) == len(prompt_logprobs)


@pytest.mark.parametrize(
    "logprobs_mode",
    ["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"])
def test_logprobs_mode(logprobs_mode: LogprobsMode,
                       monkeypatch: pytest.MonkeyPatch):
    """Test the LLM engine with different logprobs_mode settings.
    For the logprobs modes, every returned value should be non-positive.
    For the logits modes, we expect at least one positive value.
    """
    from vllm import LLM
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        llm = LLM(
            "facebook/opt-125m",
            max_logprobs=5,
            enable_prefix_caching=False,
            # 2 other LLMs are alive during the whole session
            gpu_memory_utilization=0.05,
            max_model_len=16,
            logprobs_mode=logprobs_mode)
        vllm_sampling_params = SamplingParams(logprobs=1)
        results = llm.generate(["Hello world"],
                               sampling_params=vllm_sampling_params)

        total_token_with_logprobs = 0
        positive_values = 0
        for output in results[0].outputs:
            for logprobs in output.logprobs:
                for token_id in logprobs:
                    logprob = logprobs[token_id]
                    if "logprobs" in logprobs_mode:
                        assert logprob.logprob <= 0
                    if logprob.logprob > 0:
                        positive_values = positive_values + 1
                total_token_with_logprobs = total_token_with_logprobs + 1
        assert total_token_with_logprobs >= len(results[0].outputs)
        if "logits" in logprobs_mode:
            assert positive_values > 0
        del llm

View File

@@ -219,6 +219,8 @@ def is_init_field(cls: ConfigType, name: str) -> bool:
TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
LogprobsMode = Literal["raw_logprobs", "raw_logits", "processed_logprobs",
                       "processed_logits"]


@config
@@ -316,6 +318,13 @@ class ModelConfig:
    """Maximum number of log probabilities to return when `logprobs` is
    specified in `SamplingParams`. The default value comes from the default for
    the OpenAI Chat Completions API."""
    logprobs_mode: LogprobsMode = "raw_logprobs"
    """Indicates the content returned in the logprobs and prompt_logprobs.
    Supported modes:
    1) raw_logprobs, 2) processed_logprobs, 3) raw_logits, 4) processed_logits.
    Raw means the values before applying logit processors, such as bad words.
    Processed means the values after applying such processors.
    """
    disable_sliding_window: bool = False
    """Whether to disable sliding window. If True, we will disable the sliding
    window functionality of the model, capping to sliding window size. If the

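For reference, a minimal sketch of how the new field is selected through the offline Python API, mirroring the test added above (model name and prompt are placeholders; in a *_logits mode the returned logprob field holds a raw score, so positive values are expected):

from vllm import LLM, SamplingParams

# Assumed usage based on the test above: "processed_logits" asks the sampler to
# log post-processor logits instead of log-softmax values.
llm = LLM("facebook/opt-125m", max_logprobs=5, logprobs_mode="processed_logits")
params = SamplingParams(logprobs=1)

outputs = llm.generate(["Hello world"], sampling_params=params)
for token_logprobs in outputs[0].outputs[0].logprobs:
    for token_id, entry in token_logprobs.items():
        # In a *_logits mode, entry.logprob may legitimately be positive.
        print(token_id, entry.logprob)
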
View File

@@ -26,13 +26,13 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                         DetailedTraceModules, Device, DeviceConfig,
                         DistributedExecutorBackend, GuidedDecodingBackend,
                         GuidedDecodingBackendV1, HfOverrides, KVEventsConfig,
                         KVTransferConfig, LoadConfig, LoadFormat,
                         LogprobsMode, LoRAConfig, ModelConfig, ModelDType,
                         ModelImpl, MultiModalConfig, ObservabilityConfig,
                         ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
                         PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
                         SpeculativeConfig, TaskOption, TokenizerMode,
                         VllmConfig, get_attr_docs, get_field)
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
@@ -324,6 +324,7 @@ class EngineArgs:
        SchedulerConfig.long_prefill_token_threshold
    max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
    max_logprobs: int = ModelConfig.max_logprobs
    logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
    disable_log_stats: bool = False
    revision: Optional[str] = ModelConfig.revision
    code_revision: Optional[str] = ModelConfig.code_revision
@@ -490,6 +491,8 @@ class EngineArgs:
                                 **model_kwargs["max_seq_len_to_capture"])
        model_group.add_argument("--max-logprobs",
                                 **model_kwargs["max_logprobs"])
        model_group.add_argument("--logprobs-mode",
                                 **model_kwargs["logprobs_mode"])
        model_group.add_argument("--disable-sliding-window",
                                 **model_kwargs["disable_sliding_window"])
        model_group.add_argument("--disable-cascade-attn",
@@ -892,6 +895,7 @@ class EngineArgs:
            enforce_eager=self.enforce_eager,
            max_seq_len_to_capture=self.max_seq_len_to_capture,
            max_logprobs=self.max_logprobs,
            logprobs_mode=self.logprobs_mode,
            disable_sliding_window=self.disable_sliding_window,
            disable_cascade_attn=self.disable_cascade_attn,
            skip_tokenizer_init=self.skip_tokenizer_init,

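With the field wired through EngineArgs, the mode can be picked either via the new --logprobs-mode CLI option or programmatically. A hedged sketch of the programmatic path (only an EngineArgs dataclass is constructed here, no engine is started; the default mirrors ModelConfig as shown in the diff above):

from vllm.engine.arg_utils import EngineArgs

# Assumption: EngineArgs simply forwards logprobs_mode into ModelConfig,
# as the hunks above indicate.
args = EngineArgs(model="facebook/opt-125m", logprobs_mode="raw_logits")
assert args.logprobs_mode == "raw_logits"
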
View File

@@ -5,6 +5,7 @@
import torch
import torch.nn as nn

from vllm.config import LogprobsMode
from vllm.utils import is_pin_memory_available
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
@@ -18,10 +19,11 @@ _SAMPLING_EPS = 1e-5
class Sampler(nn.Module):

    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
        super().__init__()
        self.topk_topp_sampler = TopKTopPSampler()
        self.pin_memory = is_pin_memory_available()
        self.logprobs_mode = logprobs_mode
    def forward(
        self,
@@ -36,7 +38,10 @@ class Sampler(nn.Module):
        # See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919  # noqa: E501
        num_logprobs = sampling_metadata.max_num_logprobs
        if num_logprobs is not None:
            if self.logprobs_mode == "raw_logprobs":
                raw_logprobs = self.compute_logprobs(logits)
            elif self.logprobs_mode == "raw_logits":
                raw_logprobs = logits.clone()

        # Use float32 for the logits.
        logits = logits.to(torch.float32)
@@ -51,6 +56,14 @@ class Sampler(nn.Module):
        # Apply penalties (e.g., min_tokens, freq_penalties).
        logits = self.apply_penalties(logits, sampling_metadata)

        # Get the processed logprobs or logits.
        if num_logprobs is not None:
            if self.logprobs_mode == "processed_logprobs":
                raw_logprobs = self.compute_logprobs(logits)
            elif self.logprobs_mode == "processed_logits":
                raw_logprobs = logits.clone()

        # Sample the next token.
        sampled = self.sample(logits, sampling_metadata)
        # Convert sampled token ids to int64 (long) type to ensure compatibility

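Taken together, the two hunks above mean the raw_* modes snapshot values before any logits processors run, while the processed_* modes snapshot them after penalties and other processors have been applied. A condensed, standalone illustration of that dispatch (simplified; snapshot_for_logging and the apply_processors callback are stand-ins for this sketch, not the real Sampler API):

import torch

def snapshot_for_logging(logits: torch.Tensor, logprobs_mode: str,
                         apply_processors) -> torch.Tensor:
    """Illustrative stand-in for the branch added to Sampler.forward."""
    snapshot = None
    # raw_* modes capture values before any logits processors run.
    if logprobs_mode == "raw_logprobs":
        snapshot = logits.log_softmax(dim=-1, dtype=torch.float32)
    elif logprobs_mode == "raw_logits":
        snapshot = logits.clone()

    logits = apply_processors(logits)  # e.g. penalties, bad-words masking

    # processed_* modes capture values after the processors have been applied.
    if logprobs_mode == "processed_logprobs":
        snapshot = logits.log_softmax(dim=-1, dtype=torch.float32)
    elif logprobs_mode == "processed_logits":
        snapshot = logits.clone()
    return snapshot
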
View File

@@ -15,6 +15,7 @@ _SAMPLING_EPS = 1e-5
class Sampler(nn.Module):

    def __init__(self):
        # TODO(houseroad): Add support for logprobs_mode.
        super().__init__()
        self.topk_topp_sampler = TopKTopPSampler()

View File

@@ -389,7 +389,7 @@ class InputBatch:
    def remove_request(self, req_id: str) -> Optional[int]:
        """This method must always be followed by a call to condense().

        Args:
            req_id: request to remove
@@ -590,7 +590,7 @@ class InputBatch:
    def refresh_metadata(self):
        """Apply batch updates, reset input batch at end of step

        * Apply batch add/remove/permute to logits procs' states
        * If batch state is modified, update sampling metadata
        """

View File

@@ -151,7 +151,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        self.encoder_cache_size = encoder_cache_size

        # Sampler
        self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)

        self.eplb_state: Optional[EplbState] = None
        """
@@ -1996,7 +1996,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.
        This is to help balance expert-selection
        - during profile_run
        - during DP rank dummy run
        """
        dp_size = self.vllm_config.parallel_config.data_parallel_size
        randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1