From b3d7e3c845d3f4905e9a95776c7d605c743f0ea1 Mon Sep 17 00:00:00 2001
From: Xingyu Liu <38244988+charlotte12l@users.noreply.github.com>
Date: Sun, 7 Sep 2025 19:34:31 -0700
Subject: [PATCH] [Sampler] Support returning all prompt logprobs (#23868)

Signed-off-by: Xingyu Liu
Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: Cyrus Leung
---
 tests/v1/sample/test_logprobs.py  | 12 ++++++++++--
 vllm/sampling_params.py           | 11 +++++++----
 vllm/v1/engine/processor.py       | 28 ++++++++++++++++++----------
 vllm/v1/worker/gpu_input_batch.py |  5 +++--
 4 files changed, 38 insertions(+), 18 deletions(-)

diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py
index e835c029634c..570e330208a3 100644
--- a/tests/v1/sample/test_logprobs.py
+++ b/tests/v1/sample/test_logprobs.py
@@ -430,7 +430,7 @@ def test_zero_logprobs(vllm_model, example_prompts,
 
 
 def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
-    """Engine should return all vocabulary logprobs
+    """Engine should return all vocabulary logprobs and prompt logprobs
 
     Args:
       example_prompts: list of example prompts (test fixture)
@@ -444,16 +444,24 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
            # 2 other llms alive during whole session
            gpu_memory_utilization=0.15,
            max_model_len=256)
+
        sampling_params_logprobs_all = SamplingParams(max_tokens=5,
-                                                      logprobs=-1)
+                                                      logprobs=-1,
+                                                      prompt_logprobs=-1)
        results_logprobs_all = runner.llm.generate(
            example_prompts, sampling_params=sampling_params_logprobs_all)
        vocab_size = runner.llm.llm_engine.get_model_config().get_vocab_size()
+
        for i in range(len(results_logprobs_all)):
            logprobs = results_logprobs_all[i].outputs[0].logprobs
+            prompt_logprobs = results_logprobs_all[i].prompt_logprobs
            assert logprobs is not None
            for logprob in logprobs:
                assert len(logprob) == vocab_size
+            assert prompt_logprobs is not None
+            assert prompt_logprobs[0] is None
+            for prompt_logprob in prompt_logprobs[1:]:
+                assert len(prompt_logprob) == vocab_size
 
 
 @pytest.mark.parametrize("logprobs_mode", list(LogprobsMode))
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index c7b4ba34c602..fe93e906064e 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -165,7 +165,8 @@ class SamplingParams(
     the sampled token, so there may be up to `logprobs+1` elements in the
     response. When set to -1, return all `vocab_size` log probabilities."""
     prompt_logprobs: Optional[int] = None
-    """Number of log probabilities to return per prompt token."""
+    """Number of log probabilities to return per prompt token.
+    When set to -1, return all `vocab_size` log probabilities."""
     # NOTE: This parameter is only exposed at the engine level for now.
     # It is not exposed in the OpenAI API server, as the OpenAI API does
     # not support returning only a list of token IDs.
@@ -409,9 +410,11 @@ class SamplingParams(
                 and self.logprobs < 0):
             raise ValueError(
                 f"logprobs must be non-negative or -1, got {self.logprobs}.")
-        if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
-            raise ValueError(f"prompt_logprobs must be non-negative, got "
-                             f"{self.prompt_logprobs}.")
+        if (self.prompt_logprobs is not None and self.prompt_logprobs != -1
+                and self.prompt_logprobs < 0):
+            raise ValueError(
+                f"prompt_logprobs must be non-negative or -1, got "
+                f"{self.prompt_logprobs}.")
         if (self.truncate_prompt_tokens is not None
                 and (self.truncate_prompt_tokens == 0
                      or self.truncate_prompt_tokens < -1)):
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 1aa117ded4ed..baade243140d 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -65,19 +65,27 @@ class Processor:
     ) -> None:
         max_logprobs = self.model_config.max_logprobs
         if max_logprobs == -1:
-            return
+            max_logprobs = self.model_config.get_vocab_size()
+
         # Validate sample logprobs.
-        if params.logprobs and (params.logprobs == -1
-                                or params.logprobs > max_logprobs):
-            raise ValueError(
-                f"Requested sample logprobs of {params.logprobs}, "
-                f"which is greater than max allowed: {max_logprobs}")
+        if params.logprobs:
+            num_logprobs = params.logprobs
+            if num_logprobs == -1:
+                num_logprobs = self.model_config.get_vocab_size()
+            if num_logprobs > max_logprobs:
+                raise ValueError(
+                    f"Requested sample logprobs of {num_logprobs}, "
+                    f"which is greater than max allowed: {max_logprobs}")
 
         # Validate prompt logprobs.
-        if params.prompt_logprobs and params.prompt_logprobs > max_logprobs:
-            raise ValueError(
-                f"Requested prompt logprobs of {params.prompt_logprobs}, "
-                f"which is greater than max allowed: {max_logprobs}")
+        if params.prompt_logprobs:
+            num_prompt_logprobs = params.prompt_logprobs
+            if num_prompt_logprobs == -1:
+                num_prompt_logprobs = self.model_config.get_vocab_size()
+            if num_prompt_logprobs > max_logprobs:
+                raise ValueError(
+                    f"Requested prompt logprobs of {num_prompt_logprobs}, "
+                    f"which is greater than max allowed: {max_logprobs}")
 
     def _validate_sampling_params(
         self,
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 83fc821b8494..bf9b16575e60 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -360,8 +360,9 @@ class InputBatch:
                                          if sampling_params.logprobs == -1
                                          else sampling_params.logprobs)
         if sampling_params.prompt_logprobs is not None:
-            self.num_prompt_logprobs[
-                req_id] = sampling_params.prompt_logprobs
+            self.num_prompt_logprobs[req_id] = (
+                self.vocab_size if sampling_params.prompt_logprobs == -1
+                else sampling_params.prompt_logprobs)
 
         if sampling_params.allowed_token_ids:
             self.has_allowed_token_ids.add(req_id)
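
For anyone trying the feature out, here is a minimal usage sketch of the new -1 sentinel, adapted from the updated test_all_logprobs above. The model name and prompt are placeholders rather than part of this patch, and the engine is assumed to be started with max_logprobs=-1 (or a large enough value) so the request passes the processor validation.

# Example (not part of the patch): request full-vocabulary logprobs for both
# sampled tokens and prompt tokens via the -1 sentinel.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", max_logprobs=-1)  # placeholder model

params = SamplingParams(max_tokens=5, logprobs=-1, prompt_logprobs=-1)
outputs = llm.generate(["Hello, my name is"], sampling_params=params)

vocab_size = llm.llm_engine.get_model_config().get_vocab_size()
for output in outputs:
    # The first prompt position has no preceding context, so its entry is None;
    # every later position carries logprobs over the full vocabulary.
    assert output.prompt_logprobs is not None
    assert output.prompt_logprobs[0] is None
    for prompt_logprob in output.prompt_logprobs[1:]:
        assert len(prompt_logprob) == vocab_size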