mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 01:15:26 +08:00
[Sampler] Support returning all logprobs or logits (#21792)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
parent
fed5849d3f
commit
54de71d0df
@ -429,6 +429,33 @@ def test_zero_logprobs(vllm_model, example_prompts,
|
||||
assert len(prompt_token_ids) == len(prompt_logprobs)
|
||||
|
||||
|
||||
def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Engine should return all vocabulary logprobs
|
||||
|
||||
Args:
|
||||
example_prompts: list of example prompts (test fixture)
|
||||
"""
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
runner = VllmRunner(
|
||||
"facebook/opt-125m",
|
||||
max_logprobs=-1,
|
||||
enable_prefix_caching=False,
|
||||
# 2 other llms alive during whole session
|
||||
gpu_memory_utilization=0.15,
|
||||
max_model_len=256)
|
||||
sampling_params_logprobs_all = SamplingParams(max_tokens=5,
|
||||
logprobs=-1)
|
||||
results_logprobs_all = runner.llm.generate(
|
||||
example_prompts, sampling_params=sampling_params_logprobs_all)
|
||||
vocab_size = runner.llm.llm_engine.get_model_config().get_vocab_size()
|
||||
for i in range(len(results_logprobs_all)):
|
||||
logprobs = results_logprobs_all[i].outputs[0].logprobs
|
||||
assert logprobs is not None
|
||||
for logprob in logprobs:
|
||||
assert len(logprob) == vocab_size
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"logprobs_mode",
|
||||
["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"])
|
||||
|
||||
@ -377,7 +377,8 @@ class ModelConfig:
|
||||
max_logprobs: int = 20
|
||||
"""Maximum number of log probabilities to return when `logprobs` is
|
||||
specified in `SamplingParams`. The default value comes the default for the
|
||||
OpenAI Chat Completions API."""
|
||||
OpenAI Chat Completions API. -1 means no cap, i.e. all (output_length *
|
||||
vocab_size) logprobs are allowed to be returned and it may cause OOM."""
|
||||
logprobs_mode: LogprobsMode = "raw_logprobs"
|
||||
"""Indicates the content returned in the logprobs and prompt_logprobs.
|
||||
Supported mode:
|
||||
@ -1585,7 +1586,7 @@ class ModelConfig:
|
||||
"""
|
||||
This method attempts to retrieve the non-default values of the
|
||||
generation config for this model.
|
||||
|
||||
|
||||
The generation config can contain information about special tokens, as
|
||||
well as sampling parameters. Which is why this method exists separately
|
||||
to `get_diff_sampling_param`.
|
||||
@ -2066,7 +2067,7 @@ class ParallelConfig:
|
||||
and when data_parallel_size > 0. Enables running an AsyncLLM
|
||||
and API server on a "per-node" basis where vLLM load balances
|
||||
between local data parallel ranks, but an external LB balances
|
||||
between vLLM nodes/replicas. Set explicitly in conjunction with
|
||||
between vLLM nodes/replicas. Set explicitly in conjunction with
|
||||
--data-parallel-start-rank."""
|
||||
enable_expert_parallel: bool = False
|
||||
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
|
||||
|
||||
@ -156,6 +156,7 @@ class SamplingParams(
|
||||
Note that the implementation follows the OpenAI API: The API will
|
||||
always return the log probability of the sampled token, so there
|
||||
may be up to `logprobs+1` elements in the response.
|
||||
When set to -1, return all `vocab_size` log probabilities.
|
||||
prompt_logprobs: Number of log probabilities to return per prompt token.
|
||||
detokenize: Whether to detokenize the output. Defaults to True.
|
||||
skip_special_tokens: Whether to skip special tokens in the output.
|
||||
@ -414,9 +415,10 @@ class SamplingParams(
|
||||
raise ValueError(
|
||||
f"min_tokens must be less than or equal to "
|
||||
f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
|
||||
if self.logprobs is not None and self.logprobs < 0:
|
||||
if (self.logprobs is not None and self.logprobs != -1
|
||||
and self.logprobs < 0):
|
||||
raise ValueError(
|
||||
f"logprobs must be non-negative, got {self.logprobs}.")
|
||||
f"logprobs must be non-negative or -1, got {self.logprobs}.")
|
||||
if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
|
||||
raise ValueError(f"prompt_logprobs must be non-negative, got "
|
||||
f"{self.prompt_logprobs}.")
|
||||
|
||||
@ -138,7 +138,7 @@ class LogprobsProcessor:
|
||||
|
||||
def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]:
|
||||
"""Pop and return all request prompt logprobs
|
||||
|
||||
|
||||
The logprobs processor aggregates prompt chunk logprobs
|
||||
over one or more prefill chunks. This method returns
|
||||
all prompt logprobs at once and then forgets them.
|
||||
@ -176,7 +176,8 @@ class LogprobsProcessor:
|
||||
Returns:
|
||||
dict[token id, Logprob]
|
||||
"""
|
||||
|
||||
if num_logprobs == -1:
|
||||
num_logprobs = len(logprobs)
|
||||
# We do not need a special case for the sampled token
|
||||
# being in the topk, since inserting duplicated data
|
||||
# into a dictionary twice is the same as doing it once.
|
||||
|
||||
@ -65,8 +65,11 @@ class Processor:
|
||||
params: SamplingParams,
|
||||
) -> None:
|
||||
max_logprobs = self.model_config.max_logprobs
|
||||
if max_logprobs == -1:
|
||||
return
|
||||
# Validate sample logprobs.
|
||||
if params.logprobs and params.logprobs > max_logprobs:
|
||||
if params.logprobs and (params.logprobs == -1
|
||||
or params.logprobs > max_logprobs):
|
||||
raise ValueError(
|
||||
f"Requested sample logprobs of {params.logprobs}, "
|
||||
f"which is greater than max allowed: {max_logprobs}")
|
||||
|
||||
@ -337,7 +337,9 @@ class InputBatch:
|
||||
self.generators[req_index] = request.generator
|
||||
|
||||
if sampling_params.logprobs is not None:
|
||||
self.num_logprobs[req_id] = sampling_params.logprobs
|
||||
self.num_logprobs[req_id] = (self.vocab_size
|
||||
if sampling_params.logprobs == -1
|
||||
else sampling_params.logprobs)
|
||||
if sampling_params.prompt_logprobs is not None:
|
||||
self.num_prompt_logprobs[
|
||||
req_id] = sampling_params.prompt_logprobs
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user