[Sampler] Support returning all logprobs or logits (#21792)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
parent fed5849d3f
commit 54de71d0df
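For orientation, here is a minimal offline-inference sketch of the behavior this commit enables. It assumes an engine built with this change and that LLM(...) forwards max_logprobs to the engine arguments (as VllmRunner does in the new test below); max_logprobs=-1 lifts the engine-side cap and logprobs=-1 requests log probabilities over the full vocabulary for every sampled token.

    from vllm import LLM, SamplingParams

    # Engine side: -1 removes the max_logprobs cap introduced by this change.
    llm = LLM(model="facebook/opt-125m", max_logprobs=-1, max_model_len=256)

    # Request side: -1 asks for log probabilities over the entire vocabulary.
    params = SamplingParams(max_tokens=5, logprobs=-1)
    outputs = llm.generate(["Hello, my name is"], params)

    vocab_size = llm.llm_engine.get_model_config().get_vocab_size()
    for step_logprobs in outputs[0].outputs[0].logprobs:
        # One dict entry per vocabulary token at each generated position.
        assert len(step_logprobs) == vocab_size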
@@ -429,6 +429,33 @@ def test_zero_logprobs(vllm_model, example_prompts,
         assert len(prompt_token_ids) == len(prompt_logprobs)
 
 
+def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
+    """Engine should return all vocabulary logprobs
+
+    Args:
+      example_prompts: list of example prompts (test fixture)
+    """
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        runner = VllmRunner(
+            "facebook/opt-125m",
+            max_logprobs=-1,
+            enable_prefix_caching=False,
+            # 2 other llms alive during whole session
+            gpu_memory_utilization=0.15,
+            max_model_len=256)
+        sampling_params_logprobs_all = SamplingParams(max_tokens=5,
+                                                      logprobs=-1)
+        results_logprobs_all = runner.llm.generate(
+            example_prompts, sampling_params=sampling_params_logprobs_all)
+        vocab_size = runner.llm.llm_engine.get_model_config().get_vocab_size()
+
+        for i in range(len(results_logprobs_all)):
+            logprobs = results_logprobs_all[i].outputs[0].logprobs
+            assert logprobs is not None
+            for logprob in logprobs:
+                assert len(logprob) == vocab_size
+
+
 @pytest.mark.parametrize(
     "logprobs_mode",
     ["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"])
@@ -377,7 +377,8 @@ class ModelConfig:
     max_logprobs: int = 20
     """Maximum number of log probabilities to return when `logprobs` is
     specified in `SamplingParams`. The default value comes the default for the
-    OpenAI Chat Completions API."""
+    OpenAI Chat Completions API. -1 means no cap, i.e. all (output_length *
+    vocab_size) logprobs are allowed to be returned and it may cause OOM."""
     logprobs_mode: LogprobsMode = "raw_logprobs"
     """Indicates the content returned in the logprobs and prompt_logprobs.
     Supported mode:
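The updated docstring flags the memory risk: with the cap removed, the number of returned logprob entries grows as output_length * vocab_size. A rough back-of-the-envelope sketch (the numbers below are illustrative, not measured):

    # Approximate count of (token_id, Logprob) entries for a single request
    # when logprobs=-1; the values are illustrative.
    vocab_size = 50_272      # e.g. facebook/opt-125m
    output_length = 128      # number of generated tokens
    entries_per_sequence = output_length * vocab_size
    print(entries_per_sequence)  # ~6.4 million entries, before detokenization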
@@ -1585,7 +1586,7 @@ class ModelConfig:
         """
         This method attempts to retrieve the non-default values of the
         generation config for this model.
 
         The generation config can contain information about special tokens, as
         well as sampling parameters. Which is why this method exists separately
         to `get_diff_sampling_param`.
@@ -2066,7 +2067,7 @@ class ParallelConfig:
     and when data_parallel_size > 0. Enables running an AsyncLLM
     and API server on a "per-node" basis where vLLM load balances
     between local data parallel ranks, but an external LB balances
     between vLLM nodes/replicas. Set explicitly in conjunction with
     --data-parallel-start-rank."""
     enable_expert_parallel: bool = False
     """Use expert parallelism instead of tensor parallelism for MoE layers."""
@@ -156,6 +156,7 @@ class SamplingParams(
             Note that the implementation follows the OpenAI API: The API will
             always return the log probability of the sampled token, so there
             may be up to `logprobs+1` elements in the response.
+            When set to -1, return all `vocab_size` log probabilities.
         prompt_logprobs: Number of log probabilities to return per prompt token.
         detokenize: Whether to detokenize the output. Defaults to True.
         skip_special_tokens: Whether to skip special tokens in the output.
@@ -414,9 +415,10 @@ class SamplingParams(
             raise ValueError(
                 f"min_tokens must be less than or equal to "
                 f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
-        if self.logprobs is not None and self.logprobs < 0:
+        if (self.logprobs is not None and self.logprobs != -1
+                and self.logprobs < 0):
             raise ValueError(
-                f"logprobs must be non-negative, got {self.logprobs}.")
+                f"logprobs must be non-negative or -1, got {self.logprobs}.")
         if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
             raise ValueError(f"prompt_logprobs must be non-negative, got "
                              f"{self.prompt_logprobs}.")
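A standalone sketch of the relaxed request-level check above (a hypothetical helper, not part of vLLM's API): -1 is now accepted as the "all vocabulary" sentinel, while other negative values are still rejected.

    from typing import Optional

    def check_logprobs(logprobs: Optional[int]) -> None:
        # Mirrors the new SamplingParams validation: None, -1, or any
        # non-negative integer passes; other negative values raise.
        if logprobs is not None and logprobs != -1 and logprobs < 0:
            raise ValueError(
                f"logprobs must be non-negative or -1, got {logprobs}.")

    check_logprobs(None)  # ok: logprobs disabled
    check_logprobs(5)     # ok: top-5 logprobs
    check_logprobs(-1)    # ok: all vocab_size logprobs
    # check_logprobs(-2)  # would raise ValueError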
@@ -138,7 +138,7 @@ class LogprobsProcessor:
 
     def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]:
         """Pop and return all request prompt logprobs
 
         The logprobs processor aggregates prompt chunk logprobs
         over one or more prefill chunks. This method returns
         all prompt logprobs at once and then forgets them.
@@ -176,7 +176,8 @@ class LogprobsProcessor:
         Returns:
           dict[token id, Logprob]
         """
+        if num_logprobs == -1:
+            num_logprobs = len(logprobs)
         # We do not need a special case for the sampled token
         # being in the topk, since inserting duplicated data
         # into a dictionary twice is the same as doing it once.
@@ -65,8 +65,11 @@ class Processor:
         params: SamplingParams,
     ) -> None:
         max_logprobs = self.model_config.max_logprobs
+        if max_logprobs == -1:
+            return
         # Validate sample logprobs.
-        if params.logprobs and params.logprobs > max_logprobs:
+        if params.logprobs and (params.logprobs == -1
+                                or params.logprobs > max_logprobs):
             raise ValueError(
                 f"Requested sample logprobs of {params.logprobs}, "
                 f"which is greater than max allowed: {max_logprobs}")
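The processor check above combines with ModelConfig.max_logprobs into a simple policy; a hypothetical summary helper (names are illustrative, not vLLM API):

    from typing import Optional

    def request_is_allowed(requested: Optional[int], max_logprobs: int) -> bool:
        # max_logprobs == -1 disables the engine-side cap, so any request passes;
        # otherwise logprobs == -1 (all vocab) or anything above the cap is rejected.
        if max_logprobs == -1:
            return True
        if requested and (requested == -1 or requested > max_logprobs):
            return False
        return True

    assert request_is_allowed(-1, max_logprobs=-1)      # uncapped engine: allowed
    assert not request_is_allowed(-1, max_logprobs=20)  # -1 needs max_logprobs=-1
    assert request_is_allowed(5, max_logprobs=20)       # normal top-k request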
@@ -337,7 +337,9 @@ class InputBatch:
                 self.generators[req_index] = request.generator
 
             if sampling_params.logprobs is not None:
-                self.num_logprobs[req_id] = sampling_params.logprobs
+                self.num_logprobs[req_id] = (self.vocab_size
+                                             if sampling_params.logprobs == -1
+                                             else sampling_params.logprobs)
             if sampling_params.prompt_logprobs is not None:
                 self.num_prompt_logprobs[
                     req_id] = sampling_params.prompt_logprobs
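Taken together, the -1 sentinel travels with the request and is only expanded at the edges: the InputBatch above resolves it to the vocabulary size, and the logprobs processor expands it to the full returned list. A condensed sketch of that resolution (illustrative names, not vLLM API):

    def resolve_num_logprobs(requested: int, vocab_size: int) -> int:
        # -1 is the "all logprobs" sentinel; any other value is a top-k count.
        return vocab_size if requested == -1 else requested

    assert resolve_num_logprobs(-1, vocab_size=50_272) == 50_272
    assert resolve_num_logprobs(20, vocab_size=50_272) == 20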