[Sampler] Support returning all logprobs or logits (#21792)

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Author: 22quinn
Date: 2025-08-04 03:04:12 -07:00 (committed by GitHub)
Commit: 54de71d0df (parent: fed5849d3f)
6 changed files with 45 additions and 9 deletions
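
Taken together, the changes let a request ask for the full vocabulary distribution by passing logprobs=-1, provided the engine is configured with max_logprobs=-1. A minimal offline sketch of the intended usage, assuming the facebook/opt-125m model used by the new test (the prompt text is illustrative):

from vllm import LLM, SamplingParams

# Lift the engine-level cap; per the updated ModelConfig docstring, -1 means no cap.
llm = LLM(model="facebook/opt-125m", max_logprobs=-1, max_model_len=256)

# -1 now means "return all vocab_size log probabilities per output token".
params = SamplingParams(max_tokens=5, logprobs=-1)

for output in llm.generate(["The capital of France is"], params):
    for token_logprobs in output.outputs[0].logprobs:
        # One entry per vocabulary token, as asserted by test_all_logprobs below.
        print(len(token_logprobs))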

File 1 of 6:

@@ -429,6 +429,33 @@ def test_zero_logprobs(vllm_model, example_prompts,
             assert len(prompt_token_ids) == len(prompt_logprobs)


+def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
+    """Engine should return all vocabulary logprobs
+
+    Args:
+      example_prompts: list of example prompts (test fixture)
+    """
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        runner = VllmRunner(
+            "facebook/opt-125m",
+            max_logprobs=-1,
+            enable_prefix_caching=False,
+            # 2 other llms alive during whole session
+            gpu_memory_utilization=0.15,
+            max_model_len=256)
+        sampling_params_logprobs_all = SamplingParams(max_tokens=5,
+                                                      logprobs=-1)
+        results_logprobs_all = runner.llm.generate(
+            example_prompts, sampling_params=sampling_params_logprobs_all)
+        vocab_size = runner.llm.llm_engine.get_model_config().get_vocab_size()
+        for i in range(len(results_logprobs_all)):
+            logprobs = results_logprobs_all[i].outputs[0].logprobs
+            assert logprobs is not None
+            for logprob in logprobs:
+                assert len(logprob) == vocab_size
+
+
 @pytest.mark.parametrize(
     "logprobs_mode",
     ["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"])

File 2 of 6:

@@ -377,7 +377,8 @@ class ModelConfig:
     max_logprobs: int = 20
     """Maximum number of log probabilities to return when `logprobs` is
     specified in `SamplingParams`. The default value comes the default for the
-    OpenAI Chat Completions API."""
+    OpenAI Chat Completions API. -1 means no cap, i.e. all (output_length *
+    vocab_size) logprobs are allowed to be returned and it may cause OOM."""
     logprobs_mode: LogprobsMode = "raw_logprobs"
     """Indicates the content returned in the logprobs and prompt_logprobs.
     Supported mode:
@@ -1585,7 +1586,7 @@ class ModelConfig:
         """
         This method attempts to retrieve the non-default values of the
         generation config for this model.
-
+
         The generation config can contain information about special tokens, as
         well as sampling parameters. Which is why this method exists separately
         to `get_diff_sampling_param`.
@@ -2066,7 +2067,7 @@ class ParallelConfig:
     and when data_parallel_size > 0. Enables running an AsyncLLM
     and API server on a "per-node" basis where vLLM load balances
     between local data parallel ranks, but an external LB balances
-    between vLLM nodes/replicas. Set explicitly in conjunction with
+    between vLLM nodes/replicas. Set explicitly in conjunction with
     --data-parallel-start-rank."""
     enable_expert_parallel: bool = False
     """Use expert parallelism instead of tensor parallelism for MoE layers."""

File 3 of 6:

@@ -156,6 +156,7 @@ class SamplingParams(
             Note that the implementation follows the OpenAI API: The API will
             always return the log probability of the sampled token, so there
             may be up to `logprobs+1` elements in the response.
+            When set to -1, return all `vocab_size` log probabilities.
         prompt_logprobs: Number of log probabilities to return per prompt token.
         detokenize: Whether to detokenize the output. Defaults to True.
         skip_special_tokens: Whether to skip special tokens in the output.
@@ -414,9 +415,10 @@ class SamplingParams(
             raise ValueError(
                 f"min_tokens must be less than or equal to "
                 f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
-        if self.logprobs is not None and self.logprobs < 0:
+        if (self.logprobs is not None and self.logprobs != -1
+                and self.logprobs < 0):
             raise ValueError(
-                f"logprobs must be non-negative, got {self.logprobs}.")
+                f"logprobs must be non-negative or -1, got {self.logprobs}.")
         if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
             raise ValueError(f"prompt_logprobs must be non-negative, got "
                              f"{self.prompt_logprobs}.")

File 4 of 6:

@@ -138,7 +138,7 @@ class LogprobsProcessor:

     def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]:
         """Pop and return all request prompt logprobs
-
+
         The logprobs processor aggregates prompt chunk logprobs
         over one or more prefill chunks. This method returns
         all prompt logprobs at once and then forgets them.
@@ -176,7 +176,8 @@ class LogprobsProcessor:
         Returns:
           dict[token id, Logprob]
         """
-
+        if num_logprobs == -1:
+            num_logprobs = len(logprobs)
         # We do not need a special case for the sampled token
         # being in the topk, since inserting duplicated data
         # into a dictionary twice is the same as doing it once.
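
The processor-side change is a normalization step: the -1 sentinel is rewritten to the length of the logprobs row before the per-position dict is assembled, so the existing top-k path returns every entry. A standalone sketch of that idea (the helper name is illustrative, not the module's actual API):

def resolve_num_logprobs(num_logprobs: int, logprobs_row: list[float]) -> int:
    """Map the -1 sentinel to 'every logprob in this row'."""
    return len(logprobs_row) if num_logprobs == -1 else num_logprobs

row = [-0.1, -2.3, -4.0, -7.5]             # one position over a toy 4-token vocab
assert resolve_num_logprobs(-1, row) == 4  # all entries
assert resolve_num_logprobs(2, row) == 2   # unchanged top-k behavior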

File 5 of 6:

@@ -65,8 +65,11 @@ class Processor:
         params: SamplingParams,
     ) -> None:
         max_logprobs = self.model_config.max_logprobs
+        if max_logprobs == -1:
+            return
         # Validate sample logprobs.
-        if params.logprobs and params.logprobs > max_logprobs:
+        if params.logprobs and (params.logprobs == -1
+                                or params.logprobs > max_logprobs):
             raise ValueError(
                 f"Requested sample logprobs of {params.logprobs}, "
                 f"which is greater than max allowed: {max_logprobs}")

File 6 of 6:

@@ -337,7 +337,9 @@ class InputBatch:
             self.generators[req_index] = request.generator

         if sampling_params.logprobs is not None:
-            self.num_logprobs[req_id] = sampling_params.logprobs
+            self.num_logprobs[req_id] = (self.vocab_size
+                                         if sampling_params.logprobs == -1
+                                         else sampling_params.logprobs)
         if sampling_params.prompt_logprobs is not None:
             self.num_prompt_logprobs[
                 req_id] = sampling_params.prompt_logprobs
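
On the worker side the sentinel is resolved once, when the request is added to the batch, so downstream sampling code only ever sees a concrete count. A toy sketch of that bookkeeping (the dict mirrors InputBatch.num_logprobs; the surrounding class and vocab size are assumptions):

from typing import Optional

vocab_size = 50_272                 # illustrative model vocabulary size
num_logprobs: dict[str, int] = {}   # per-request logprob counts, as in InputBatch

def add_request(req_id: str, requested_logprobs: Optional[int]) -> None:
    if requested_logprobs is not None:
        num_logprobs[req_id] = (vocab_size
                                if requested_logprobs == -1
                                else requested_logprobs)

add_request("req-0", -1)
assert num_logprobs["req-0"] == vocab_size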