From 54de71d0dfbb6340fdbc620f4ebeb4236d165a37 Mon Sep 17 00:00:00 2001
From: 22quinn <33176974+22quinn@users.noreply.github.com>
Date: Mon, 4 Aug 2025 03:04:12 -0700
Subject: [PATCH] [Sampler] Support returning all logprobs or logits (#21792)

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
---
 tests/v1/sample/test_logprobs.py  | 27 +++++++++++++++++++++++++++
 vllm/config.py                    |  7 ++++---
 vllm/sampling_params.py           |  6 ++++--
 vllm/v1/engine/logprobs.py        |  5 +++--
 vllm/v1/engine/processor.py       |  5 ++++-
 vllm/v1/worker/gpu_input_batch.py |  4 +++-
 6 files changed, 45 insertions(+), 9 deletions(-)

diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py
index 680e2ce98bb27..8bd142e87b06e 100644
--- a/tests/v1/sample/test_logprobs.py
+++ b/tests/v1/sample/test_logprobs.py
@@ -429,6 +429,33 @@ def test_zero_logprobs(vllm_model, example_prompts,
     assert len(prompt_token_ids) == len(prompt_logprobs)
 
 
+def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
+    """Engine should return all vocabulary logprobs
+
+    Args:
+      example_prompts: list of example prompts (test fixture)
+    """
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        runner = VllmRunner(
+            "facebook/opt-125m",
+            max_logprobs=-1,
+            enable_prefix_caching=False,
+            # 2 other llms alive during whole session
+            gpu_memory_utilization=0.15,
+            max_model_len=256)
+        sampling_params_logprobs_all = SamplingParams(max_tokens=5,
+                                                      logprobs=-1)
+        results_logprobs_all = runner.llm.generate(
+            example_prompts, sampling_params=sampling_params_logprobs_all)
+        vocab_size = runner.llm.llm_engine.get_model_config().get_vocab_size()
+        for i in range(len(results_logprobs_all)):
+            logprobs = results_logprobs_all[i].outputs[0].logprobs
+            assert logprobs is not None
+            for logprob in logprobs:
+                assert len(logprob) == vocab_size
+
+
 @pytest.mark.parametrize(
     "logprobs_mode",
     ["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"])
diff --git a/vllm/config.py b/vllm/config.py
index 871df455ef58f..5c300e327397b 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -377,7 +377,8 @@ class ModelConfig:
     max_logprobs: int = 20
     """Maximum number of log probabilities to return when `logprobs` is
     specified in `SamplingParams`. The default value comes the default for the
-    OpenAI Chat Completions API."""
+    OpenAI Chat Completions API. -1 means no cap, i.e. all (output_length *
+    vocab_size) logprobs are allowed to be returned and it may cause OOM."""
     logprobs_mode: LogprobsMode = "raw_logprobs"
     """Indicates the content returned in the logprobs and prompt_logprobs.
     Supported mode:
@@ -1585,7 +1586,7 @@ class ModelConfig:
         """
         This method attempts to retrieve the non-default values of the
         generation config for this model.
-        
+
         The generation config can contain information about special tokens, as
         well as sampling parameters. Which is why this method exists separately
         to `get_diff_sampling_param`.
@@ -2066,7 +2067,7 @@ class ParallelConfig:
     and when data_parallel_size > 0. Enables running an AsyncLLM
     and API server on a "per-node" basis where vLLM load balances
     between local data parallel ranks, but an external LB balances
-    between vLLM nodes/replicas. Set explicitly in conjunction with 
+    between vLLM nodes/replicas. Set explicitly in conjunction with
     --data-parallel-start-rank."""
     enable_expert_parallel: bool = False
     """Use expert parallelism instead of tensor parallelism for MoE layers."""
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 322e53b753948..52e4cbd096153 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -156,6 +156,7 @@ class SamplingParams(
             Note that the implementation follows the OpenAI API: The API will
             always return the log probability of the sampled token, so there
             may be up to `logprobs+1` elements in the response.
+            When set to -1, return all `vocab_size` log probabilities.
         prompt_logprobs: Number of log probabilities to return per prompt token.
         detokenize: Whether to detokenize the output. Defaults to True.
         skip_special_tokens: Whether to skip special tokens in the output.
@@ -414,9 +415,10 @@ class SamplingParams(
             raise ValueError(
                 f"min_tokens must be less than or equal to "
                 f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
-        if self.logprobs is not None and self.logprobs < 0:
+        if (self.logprobs is not None and self.logprobs != -1
+                and self.logprobs < 0):
             raise ValueError(
-                f"logprobs must be non-negative, got {self.logprobs}.")
+                f"logprobs must be non-negative or -1, got {self.logprobs}.")
         if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
             raise ValueError(f"prompt_logprobs must be non-negative, got "
                              f"{self.prompt_logprobs}.")
diff --git a/vllm/v1/engine/logprobs.py b/vllm/v1/engine/logprobs.py
index e95da0a5e5aaf..3de7fa6889e55 100644
--- a/vllm/v1/engine/logprobs.py
+++ b/vllm/v1/engine/logprobs.py
@@ -138,7 +138,7 @@ class LogprobsProcessor:
 
     def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]:
         """Pop and return all request prompt logprobs
-        
+
         The logprobs processor aggregates prompt chunk logprobs
         over one or more prefill chunks. This method returns
         all prompt logprobs at once and then forgets them.
@@ -176,7 +176,8 @@ class LogprobsProcessor:
         Returns:
           dict[token id, Logprob]
         """
-
+        if num_logprobs == -1:
+            num_logprobs = len(logprobs)
         # We do not need a special case for the sampled token
         # being in the topk, since inserting duplicated data
         # into a dictionary twice is the same as doing it once.
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 224acc47feb27..692a7dd5640e0 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -65,8 +65,11 @@ class Processor:
         params: SamplingParams,
     ) -> None:
         max_logprobs = self.model_config.max_logprobs
+        if max_logprobs == -1:
+            return
         # Validate sample logprobs.
-        if params.logprobs and params.logprobs > max_logprobs:
+        if params.logprobs and (params.logprobs == -1
+                                or params.logprobs > max_logprobs):
             raise ValueError(
                 f"Requested sample logprobs of {params.logprobs}, "
                 f"which is greater than max allowed: {max_logprobs}")
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index c63041600f388..d9d0b4bec871a 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -337,7 +337,9 @@ class InputBatch:
             self.generators[req_index] = request.generator
 
         if sampling_params.logprobs is not None:
-            self.num_logprobs[req_id] = sampling_params.logprobs
+            self.num_logprobs[req_id] = (self.vocab_size
+                                         if sampling_params.logprobs == -1
+                                         else sampling_params.logprobs)
         if sampling_params.prompt_logprobs is not None:
             self.num_prompt_logprobs[
                 req_id] = sampling_params.prompt_logprobs
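
Usage sketch (not part of the patch): a minimal offline-inference example of
the new -1 sentinel, mirroring test_all_logprobs above. The model name and
prompt are placeholders; max_logprobs=-1 is assumed to be accepted as an
engine argument by LLM(), as it is via VllmRunner in the test.

    from vllm import LLM, SamplingParams

    # Lift the engine-level cap so per-request logprobs=-1 is not rejected.
    llm = LLM(model="facebook/opt-125m", max_logprobs=-1, max_model_len=256)
    # logprobs=-1 requests the full vocabulary of logprobs per output token.
    params = SamplingParams(max_tokens=5, logprobs=-1)

    outputs = llm.generate(["Hello, my name is"], sampling_params=params)
    vocab_size = llm.llm_engine.get_model_config().get_vocab_size()
    for token_logprobs in outputs[0].outputs[0].logprobs:
        # Each per-token dict should now span the entire vocabulary,
        # so memory use grows with output_length * vocab_size.
        assert len(token_logprobs) == vocab_size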