[Frontend] Do prompt_logprobs clamping for chat as well as completions (#14225)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-03-04 21:13:06 +01:00 (committed by GitHub)
parent 9badee53de
commit e5b2f1601a
3 changed files with 22 additions and 11 deletions

vllm/entrypoints/openai/serving_chat.py

@@ -24,7 +24,8 @@ from vllm.entrypoints.openai.protocol import (
     RequestResponseMetadata, ToolCall, UsageInfo)
 from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
                                                        ReasoningParserManager)
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
+                                                    clamp_prompt_logprobs)
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
 from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
@@ -844,7 +845,7 @@ class OpenAIServingChat(OpenAIServing):
             model=model_name,
             choices=choices,
             usage=usage,
-            prompt_logprobs=final_res.prompt_logprobs,
+            prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs),
         )

         return response
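
For context on why the chat path needs this: prompt logprobs can be requested from the chat endpoint too (vLLM's prompt_logprobs request extension), and a prompt token the model assigns zero probability comes back as float('-inf'), which serializes to the invalid JSON token -Infinity. A hedged sketch of exercising that path, assuming a local OpenAI-compatible vLLM server; the URL and model name are placeholders, not part of this commit:

import requests

# Placeholder server and model; adjust for your deployment.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "my-model",
        "messages": [{"role": "user", "content": "Hello"}],
        # vLLM extension: also return logprobs for the prompt tokens.
        "prompt_logprobs": 1,
    },
)
# Before this commit, a -inf prompt logprob surfaced here as -Infinity;
# with the shared clamp it is reported as -9999.0, matching completions.
print(resp.json().get("prompt_logprobs"))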

vllm/entrypoints/openai/serving_completion.py

@@ -23,7 +23,8 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                               RequestResponseMetadata,
                                               UsageInfo)
 # yapf: enable
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
+                                                    clamp_prompt_logprobs)
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
@@ -394,13 +395,7 @@ class OpenAIServingCompletion(OpenAIServing):
         for final_res in final_res_batch:
             prompt_token_ids = final_res.prompt_token_ids
             assert prompt_token_ids is not None
-            prompt_logprobs = final_res.prompt_logprobs
-            if prompt_logprobs:
-                for logprob_dict in prompt_logprobs:
-                    if logprob_dict:
-                        for logprob_values in logprob_dict.values():
-                            if logprob_values.logprob == float('-inf'):
-                                logprob_values.logprob = -9999.0
+            prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
             prompt_text = final_res.prompt

             token_ids: GenericSequence[int]

vllm/entrypoints/openai/serving_engine.py

@@ -42,7 +42,7 @@ from vllm.lora.request import LoRARequest
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import BeamSearchParams, SamplingParams
-from vllm.sequence import Logprob
+from vllm.sequence import Logprob, PromptLogprobs
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                           log_tracing_disabled_warning)
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
@@ -535,3 +535,18 @@ class OpenAIServing:
         if model_name is None:
             return self.models.base_model_paths[0].name
         return model_name
+
+
+def clamp_prompt_logprobs(
+        prompt_logprobs: Union[PromptLogprobs,
+                               None]) -> Union[PromptLogprobs, None]:
+    if prompt_logprobs is None:
+        return prompt_logprobs
+
+    for logprob_dict in prompt_logprobs:
+        if logprob_dict is None:
+            continue
+        for logprob_values in logprob_dict.values():
+            if logprob_values.logprob == float('-inf'):
+                logprob_values.logprob = -9999.0
+    return prompt_logprobs
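
For reference, the reason for clamping at all: Python's json module serializes float('-inf') as -Infinity, which is not valid JSON and breaks strict clients. A minimal self-contained sketch of the same clamping behaviour, using a stand-in dataclass rather than vllm.sequence.Logprob (the names below are illustrative, not vLLM's API):

import json
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class StubLogprob:  # stand-in for vllm.sequence.Logprob
    logprob: float


def clamp(prompt_logprobs: Optional[List[Optional[Dict[int, StubLogprob]]]]):
    # Same idea as clamp_prompt_logprobs above: replace -inf with a large
    # negative finite value so the response serializes as valid JSON.
    if prompt_logprobs is None:
        return prompt_logprobs
    for logprob_dict in prompt_logprobs:
        if logprob_dict is None:
            continue
        for logprob_values in logprob_dict.values():
            if logprob_values.logprob == float('-inf'):
                logprob_values.logprob = -9999.0
    return prompt_logprobs


def as_json(lps):
    # Serialize only the logprob values, the way a response body would.
    return json.dumps([d and {k: v.logprob for k, v in d.items()} for d in lps])


logprobs = [None, {42: StubLogprob(logprob=float('-inf'))}]
print(as_json(logprobs))         # [null, {"42": -Infinity}]  <- not valid JSON
print(as_json(clamp(logprobs)))  # [null, {"42": -9999.0}]

The -9999.0 sentinel is the value the completions endpoint already emitted, so after this commit both endpoints report impossible prompt tokens the same way.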