[Sampler] Support returning all prompt logprobs (#23868)
Signed-off-by: Xingyu Liu <charlotteliu12x@gmail.com>
Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
parent 67841317d1
commit b3d7e3c845
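With this change, `prompt_logprobs=-1` becomes a sentinel meaning "return the full `vocab_size` distribution for every prompt position", mirroring the existing `logprobs=-1` behaviour for sampled tokens. A minimal offline sketch of the new usage, assuming the engine is started with `max_logprobs=-1` so the full-vocabulary request passes validation (the model name is illustrative, not part of this commit):

```python
from vllm import LLM, SamplingParams

# Illustrative small model; any supported model works. max_logprobs=-1
# lifts the engine-level cap so a full-vocab request is allowed.
llm = LLM(model="facebook/opt-125m", max_model_len=256, max_logprobs=-1)

params = SamplingParams(
    max_tokens=5,
    logprobs=-1,         # full vocabulary for each sampled token
    prompt_logprobs=-1,  # new: full vocabulary for each prompt token
)

out = llm.generate(["The capital of France is"], sampling_params=params)[0]

# The first prompt token has no preceding context, so its entry is None;
# every later position carries a vocab_size-sized mapping.
assert out.prompt_logprobs[0] is None
print(len(out.prompt_logprobs[1]))
```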
@@ -430,7 +430,7 @@ def test_zero_logprobs(vllm_model, example_prompts,
 
 
 def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
-    """Engine should return all vocabulary logprobs
+    """Engine should return all vocabulary logprobs and prompt logprobs
 
     Args:
         example_prompts: list of example prompts (test fixture)
@@ -444,16 +444,24 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
                             # 2 other llms alive during whole session
                             gpu_memory_utilization=0.15,
                             max_model_len=256)
 
         sampling_params_logprobs_all = SamplingParams(max_tokens=5,
-                                                      logprobs=-1)
+                                                      logprobs=-1,
+                                                      prompt_logprobs=-1)
         results_logprobs_all = runner.llm.generate(
             example_prompts, sampling_params=sampling_params_logprobs_all)
         vocab_size = runner.llm.llm_engine.get_model_config().get_vocab_size()
 
         for i in range(len(results_logprobs_all)):
             logprobs = results_logprobs_all[i].outputs[0].logprobs
+            prompt_logprobs = results_logprobs_all[i].prompt_logprobs
             assert logprobs is not None
             for logprob in logprobs:
                 assert len(logprob) == vocab_size
+            assert prompt_logprobs is not None
+            assert prompt_logprobs[0] is None
+            for prompt_logprob in prompt_logprobs[1:]:
+                assert len(prompt_logprob) == vocab_size
 
 
 @pytest.mark.parametrize("logprobs_mode", list(LogprobsMode))
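Since every prompt position after the first now carries a full `vocab_size` mapping, downstream code can compute distribution-level statistics over the prompt rather than just the rank of the actual token. A hedged sketch, assuming the usual vllm convention that each entry maps a token id to a `Logprob` object with a `.logprob` attribute, and that the default logprobs mode (normalized log-probabilities) is in effect:

```python
import math
from typing import Optional


def prompt_entropies(prompt_logprobs: list[Optional[dict]]) -> list[float]:
    """Per-position entropy (in nats) of the next-token distribution over
    the prompt. Position 0 is None (no preceding context) and is skipped."""
    entropies = []
    for position in prompt_logprobs[1:]:
        entropies.append(-sum(
            math.exp(lp.logprob) * lp.logprob for lp in position.values()))
    return entropies


# e.g. prompt_entropies(results_logprobs_all[0].prompt_logprobs)
```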
@@ -165,7 +165,8 @@ class SamplingParams(
     the sampled token, so there may be up to `logprobs+1` elements in the
     response. When set to -1, return all `vocab_size` log probabilities."""
     prompt_logprobs: Optional[int] = None
-    """Number of log probabilities to return per prompt token."""
+    """Number of log probabilities to return per prompt token.
+    When set to -1, return all `vocab_size` log probabilities."""
     # NOTE: This parameter is only exposed at the engine level for now.
     # It is not exposed in the OpenAI API server, as the OpenAI API does
     # not support returning only a list of token IDs.
@@ -409,9 +410,11 @@ class SamplingParams(
                 and self.logprobs < 0):
             raise ValueError(
                 f"logprobs must be non-negative or -1, got {self.logprobs}.")
-        if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
-            raise ValueError(f"prompt_logprobs must be non-negative, got "
-                             f"{self.prompt_logprobs}.")
+        if (self.prompt_logprobs is not None and self.prompt_logprobs != -1
+                and self.prompt_logprobs < 0):
+            raise ValueError(
+                f"prompt_logprobs must be non-negative or -1, got "
+                f"{self.prompt_logprobs}.")
         if (self.truncate_prompt_tokens is not None
                 and (self.truncate_prompt_tokens == 0
                      or self.truncate_prompt_tokens < -1)):
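The relaxed check keeps rejecting genuinely negative values while letting -1 through as the explicit "all vocabulary" sentinel; validation still happens eagerly when the `SamplingParams` object is constructed. A quick illustration of the intended behaviour:

```python
import pytest
from vllm import SamplingParams

SamplingParams(prompt_logprobs=5)    # top-5 logprobs per prompt token
SamplingParams(prompt_logprobs=-1)   # new: full vocabulary per prompt token

with pytest.raises(ValueError):
    SamplingParams(prompt_logprobs=-2)  # still rejected
```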
@@ -65,19 +65,27 @@ class Processor:
     ) -> None:
         max_logprobs = self.model_config.max_logprobs
         if max_logprobs == -1:
-            return
+            max_logprobs = self.model_config.get_vocab_size()
 
         # Validate sample logprobs.
-        if params.logprobs and (params.logprobs == -1
-                                or params.logprobs > max_logprobs):
-            raise ValueError(
-                f"Requested sample logprobs of {params.logprobs}, "
-                f"which is greater than max allowed: {max_logprobs}")
+        if params.logprobs:
+            num_logprobs = params.logprobs
+            if num_logprobs == -1:
+                num_logprobs = self.model_config.get_vocab_size()
+            if num_logprobs > max_logprobs:
+                raise ValueError(
+                    f"Requested sample logprobs of {num_logprobs}, "
+                    f"which is greater than max allowed: {max_logprobs}")
 
         # Validate prompt logprobs.
-        if params.prompt_logprobs and params.prompt_logprobs > max_logprobs:
-            raise ValueError(
-                f"Requested prompt logprobs of {params.prompt_logprobs}, "
-                f"which is greater than max allowed: {max_logprobs}")
+        if params.prompt_logprobs:
+            num_prompt_logprobs = params.prompt_logprobs
+            if num_prompt_logprobs == -1:
+                num_prompt_logprobs = self.model_config.get_vocab_size()
+            if num_prompt_logprobs > max_logprobs:
+                raise ValueError(
+                    f"Requested prompt logprobs of {num_prompt_logprobs}, "
+                    f"which is greater than max allowed: {max_logprobs}")
 
     def _validate_sampling_params(
         self,
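Instead of early-returning when the engine-level `max_logprobs` is -1, both the per-request value and the engine cap are now normalized to `vocab_size` before the comparison, so a request for -1 only succeeds when the engine actually permits the full vocabulary. The comparison logic restated as a standalone sketch (the function name is illustrative, not part of the commit):

```python
def check_logprobs_request(requested: int, max_logprobs: int,
                           vocab_size: int, kind: str = "sample") -> None:
    """Mirror of the validation above: -1 means 'all vocab_size logprobs'
    on both sides of the comparison."""
    if not requested:
        return
    cap = vocab_size if max_logprobs == -1 else max_logprobs
    num = vocab_size if requested == -1 else requested
    if num > cap:
        raise ValueError(f"Requested {kind} logprobs of {num}, "
                         f"which is greater than max allowed: {cap}")


# check_logprobs_request(-1, -1, 32000)  -> passes (engine cap lifted)
# check_logprobs_request(-1, 20, 32000)  -> ValueError (cap is 20)
```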
@@ -360,8 +360,9 @@ class InputBatch:
                                          if sampling_params.logprobs == -1
                                          else sampling_params.logprobs)
         if sampling_params.prompt_logprobs is not None:
-            self.num_prompt_logprobs[
-                req_id] = sampling_params.prompt_logprobs
+            self.num_prompt_logprobs[req_id] = (
+                self.vocab_size if sampling_params.prompt_logprobs == -1
+                else sampling_params.prompt_logprobs)
 
         if sampling_params.allowed_token_ids:
             self.has_allowed_token_ids.add(req_id)
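The sentinel is resolved once when the request joins the batch, matching how sample logprobs are already handled, so the rest of the sampler bookkeeping only ever deals with a concrete count. In isolation (names are illustrative, not part of the commit):

```python
from typing import Optional


def resolve_prompt_logprobs(requested: Optional[int],
                            vocab_size: int) -> Optional[int]:
    """None means 'not requested'; -1 expands to the full vocabulary."""
    if requested is None:
        return None
    return vocab_size if requested == -1 else requested


assert resolve_prompt_logprobs(None, 32000) is None
assert resolve_prompt_logprobs(5, 32000) == 5
assert resolve_prompt_logprobs(-1, 32000) == 32000
```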