From accac82928477f87e1082ba501c2d43622556275 Mon Sep 17 00:00:00 2001
From: Lu Fang <30275821+houseroad@users.noreply.github.com>
Date: Wed, 23 Jul 2025 01:39:25 -0700
Subject: [PATCH] [Sampler] Introduce logprobs mode for logging (#21398)

Signed-off-by: Lu Fang
---
 tests/v1/sample/test_logprobs.py   | 43 ++++++++++++++++++++++++++++++
 vllm/config.py                     |  9 +++++++
 vllm/engine/arg_utils.py           | 18 ++++++++-----
 vllm/v1/sample/sampler.py          | 17 ++++++++++--
 vllm/v1/sample/tpu/sampler.py      |  1 +
 vllm/v1/worker/gpu_input_batch.py  |  4 +--
 vllm/v1/worker/gpu_model_runner.py |  4 +--
 7 files changed, 83 insertions(+), 13 deletions(-)

diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py
index 4f1f340a4ccb..680e2ce98bb2 100644
--- a/tests/v1/sample/test_logprobs.py
+++ b/tests/v1/sample/test_logprobs.py
@@ -12,6 +12,7 @@ from tests.v1.sample.utils import (
     assert_incr_detok_str_matches_non_incr_detok_str,
     compute_correct_cumulative_logprob, get_test_batch)
 from vllm import SamplingParams
+from vllm.config import LogprobsMode
 
 from ...conftest import HfRunner, VllmRunner
 
@@ -426,3 +427,45 @@ def test_zero_logprobs(vllm_model, example_prompts,
         # prompt token
         assert prompt_logprobs is not None
         assert len(prompt_token_ids) == len(prompt_logprobs)
+
+
+@pytest.mark.parametrize(
+    "logprobs_mode",
+    ["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"])
+def test_logprobs_mode(logprobs_mode: LogprobsMode,
+                       monkeypatch: pytest.MonkeyPatch):
+    """Test the LLM engine with different logprobs_mode settings.
+    For logprobs, we should have non-positive values.
+    For logits, we should expect at least one positive value.
+    """
+    from vllm import LLM
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        llm = LLM(
+            "facebook/opt-125m",
+            max_logprobs=5,
+            enable_prefix_caching=False,
+            # 2 other llms alive during whole session
+            gpu_memory_utilization=0.05,
+            max_model_len=16,
+            logprobs_mode=logprobs_mode)
+        vllm_sampling_params = SamplingParams(logprobs=1)
+        results = llm.generate(["Hello world"],
+                               sampling_params=vllm_sampling_params)
+
+        total_token_with_logprobs = 0
+        positive_values = 0
+        for output in results[0].outputs:
+            for logprobs in output.logprobs:
+                for token_id in logprobs:
+                    logprob = logprobs[token_id]
+                    if "logprobs" in logprobs_mode:
+                        assert logprob.logprob <= 0
+                    if logprob.logprob > 0:
+                        positive_values = positive_values + 1
+                    total_token_with_logprobs = total_token_with_logprobs + 1
+        assert total_token_with_logprobs >= len(results[0].outputs)
+        if "logits" in logprobs_mode:
+            assert positive_values > 0
+        del llm
diff --git a/vllm/config.py b/vllm/config.py
index 6623a48f839a..223c1968c275 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -219,6 +219,8 @@ def is_init_field(cls: ConfigType, name: str) -> bool:
 
 TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
 ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
+LogprobsMode = Literal["raw_logprobs", "raw_logits", "processed_logprobs",
+                       "processed_logits"]
 
 
 @config
@@ -316,6 +318,13 @@ class ModelConfig:
     """Maximum number of log probabilities to return when `logprobs` is
     specified in `SamplingParams`. The default value comes the default for the
     OpenAI Chat Completions API."""
+    logprobs_mode: LogprobsMode = "raw_logprobs"
+    """Indicates the content returned in the logprobs and prompt_logprobs.
+    Supported modes:
+    1) raw_logprobs, 2) processed_logprobs, 3) raw_logits, 4) processed_logits.
+    Raw means the values before applying logit processors, like bad words.
+    Processed means the values after applying such processors.
+    """
     disable_sliding_window: bool = False
     """Whether to disable sliding window. If True, we will disable the sliding
     window functionality of the model, capping to sliding window size. If the
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 1e3d46a8d96e..4a5efd40241d 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -26,13 +26,13 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                          DetailedTraceModules, Device, DeviceConfig,
                          DistributedExecutorBackend, GuidedDecodingBackend,
                          GuidedDecodingBackendV1, HfOverrides, KVEventsConfig,
-                         KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
-                         ModelConfig, ModelDType, ModelImpl, MultiModalConfig,
-                         ObservabilityConfig, ParallelConfig, PoolerConfig,
-                         PrefixCachingHashAlgo, PromptAdapterConfig,
-                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
-                         TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
-                         get_field)
+                         KVTransferConfig, LoadConfig, LoadFormat,
+                         LogprobsMode, LoRAConfig, ModelConfig, ModelDType,
+                         ModelImpl, MultiModalConfig, ObservabilityConfig,
+                         ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
+                         PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
+                         SpeculativeConfig, TaskOption, TokenizerMode,
+                         VllmConfig, get_attr_docs, get_field)
 from vllm.logger import init_logger
 from vllm.platforms import CpuArchEnum, current_platform
 from vllm.plugins import load_general_plugins
@@ -324,6 +324,7 @@ class EngineArgs:
         SchedulerConfig.long_prefill_token_threshold
     max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
     max_logprobs: int = ModelConfig.max_logprobs
+    logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
    disable_log_stats: bool = False
     revision: Optional[str] = ModelConfig.revision
     code_revision: Optional[str] = ModelConfig.code_revision
@@ -490,6 +491,8 @@ class EngineArgs:
                                  **model_kwargs["max_seq_len_to_capture"])
         model_group.add_argument("--max-logprobs",
                                  **model_kwargs["max_logprobs"])
+        model_group.add_argument("--logprobs-mode",
+                                 **model_kwargs["logprobs_mode"])
         model_group.add_argument("--disable-sliding-window",
                                  **model_kwargs["disable_sliding_window"])
         model_group.add_argument("--disable-cascade-attn",
@@ -892,6 +895,7 @@ class EngineArgs:
             enforce_eager=self.enforce_eager,
             max_seq_len_to_capture=self.max_seq_len_to_capture,
             max_logprobs=self.max_logprobs,
+            logprobs_mode=self.logprobs_mode,
             disable_sliding_window=self.disable_sliding_window,
             disable_cascade_attn=self.disable_cascade_attn,
             skip_tokenizer_init=self.skip_tokenizer_init,
diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py
index fa078e628768..82f51298f1b5 100644
--- a/vllm/v1/sample/sampler.py
+++ b/vllm/v1/sample/sampler.py
@@ -5,6 +5,7 @@
 import torch
 import torch.nn as nn
 
+from vllm.config import LogprobsMode
 from vllm.utils import is_pin_memory_available
 from vllm.v1.outputs import LogprobsTensors, SamplerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
@@ -18,10 +19,11 @@ _SAMPLING_EPS = 1e-5
 
 class Sampler(nn.Module):
 
-    def __init__(self):
+    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
         super().__init__()
         self.topk_topp_sampler = TopKTopPSampler()
         self.pin_memory = is_pin_memory_available()
+        self.logprobs_mode = logprobs_mode
 
     def forward(
         self,
@@ -36,7 +38,10 @@ class Sampler(nn.Module):
         # See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501
         num_logprobs = sampling_metadata.max_num_logprobs
         if num_logprobs is not None:
-            raw_logprobs = self.compute_logprobs(logits)
+            if self.logprobs_mode == "raw_logprobs":
+                raw_logprobs = self.compute_logprobs(logits)
+            elif self.logprobs_mode == "raw_logits":
+                raw_logprobs = logits.clone()
 
         # Use float32 for the logits.
         logits = logits.to(torch.float32)
@@ -51,6 +56,14 @@ class Sampler(nn.Module):
 
         # Apply penalties (e.g., min_tokens, freq_penalties).
         logits = self.apply_penalties(logits, sampling_metadata)
+
+        # Get the processed logprobs or logits.
+        if num_logprobs is not None:
+            if self.logprobs_mode == "processed_logprobs":
+                raw_logprobs = self.compute_logprobs(logits)
+            elif self.logprobs_mode == "processed_logits":
+                raw_logprobs = logits.clone()
+
         # Sample the next token.
         sampled = self.sample(logits, sampling_metadata)
         # Convert sampled token ids to int64 (long) type to ensure compatibility
diff --git a/vllm/v1/sample/tpu/sampler.py b/vllm/v1/sample/tpu/sampler.py
index 1056eb1d7b7f..2c9f4892bc24 100644
--- a/vllm/v1/sample/tpu/sampler.py
+++ b/vllm/v1/sample/tpu/sampler.py
@@ -15,6 +15,7 @@ _SAMPLING_EPS = 1e-5
 class Sampler(nn.Module):
 
     def __init__(self):
+        # TODO(houseroad): Add support for logprobs_mode.
         super().__init__()
         self.topk_topp_sampler = TopKTopPSampler()
 
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index a242c7fca5ef..c63041600f38 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -389,7 +389,7 @@ class InputBatch:
 
     def remove_request(self, req_id: str) -> Optional[int]:
         """This method must always be followed by a call to condense().
-        
+
         Args:
             req_id: request to remove
 
@@ -590,7 +590,7 @@ class InputBatch:
 
     def refresh_metadata(self):
         """Apply batch updates, reset input batch at end of step
-        
+
         * Apply batch add/remove/permute to logits procs' states
         * If batch state is modified, update sampling metadata
         """
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 4c14ac3be3c0..6a42e01f14b0 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -151,7 +151,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.encoder_cache_size = encoder_cache_size
 
         # Sampler
-        self.sampler = Sampler()
+        self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
 
         self.eplb_state: Optional[EplbState] = None
         """
@@ -1996,7 +1996,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.
         This is to help balance expert-selection
          - during profile_run
-         - during DP rank dummy run 
+         - during DP rank dummy run
         """
         dp_size = self.vllm_config.parallel_config.data_parallel_size
         randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1
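
Usage sketch (not part of the patch): the snippet below shows one way the new
option could be exercised offline. It mirrors the test_logprobs_mode test
above; the model name, prompt, and parameter values are illustrative only.

    from vllm import LLM, SamplingParams

    # Not part of the patch: illustrative use of the new logprobs_mode option.
    # "processed_logits" returns the values after logit processors (e.g. bad
    # words) and penalties have been applied; the default remains
    # "raw_logprobs" (see LogprobsMode in vllm/config.py).
    llm = LLM("facebook/opt-125m",
              max_logprobs=5,
              logprobs_mode="processed_logits")

    outputs = llm.generate(["Hello world"],
                           sampling_params=SamplingParams(logprobs=1))

    for completion in outputs[0].outputs:
        for token_logprobs in completion.logprobs:
            for token_id, logprob in token_logprobs.items():
                # In the *_logits modes these are raw scores, so positive
                # values are expected; in the *_logprobs modes they are <= 0.
                print(token_id, logprob.logprob)

The same setting is exposed on the command line as --logprobs-mode, added to
the model argument group in vllm/engine/arg_utils.py above.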