mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-20 04:05:01 +08:00
[Frontend] Do prompt_logprobs clamping for chat as well as completions (#14225)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
9badee53de
commit
e5b2f1601a
@ -24,7 +24,8 @@ from vllm.entrypoints.openai.protocol import (
|
||||
RequestResponseMetadata, ToolCall, UsageInfo)
|
||||
from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
|
||||
ReasoningParserManager)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
|
||||
clamp_prompt_logprobs)
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
|
||||
@ -844,7 +845,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
model=model_name,
|
||||
choices=choices,
|
||||
usage=usage,
|
||||
prompt_logprobs=final_res.prompt_logprobs,
|
||||
prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs),
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@ -23,7 +23,8 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
|
||||
RequestResponseMetadata,
|
||||
UsageInfo)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
|
||||
clamp_prompt_logprobs)
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
@ -394,13 +395,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
for final_res in final_res_batch:
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
assert prompt_token_ids is not None
|
||||
prompt_logprobs = final_res.prompt_logprobs
|
||||
if prompt_logprobs:
|
||||
for logprob_dict in prompt_logprobs:
|
||||
if logprob_dict:
|
||||
for logprob_values in logprob_dict.values():
|
||||
if logprob_values.logprob == float('-inf'):
|
||||
logprob_values.logprob = -9999.0
|
||||
prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
|
||||
prompt_text = final_res.prompt
|
||||
|
||||
token_ids: GenericSequence[int]
|
||||
|
||||
@ -42,7 +42,7 @@ from vllm.lora.request import LoRARequest
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.sequence import Logprob, PromptLogprobs
|
||||
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
||||
log_tracing_disabled_warning)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
@ -535,3 +535,18 @@ class OpenAIServing:
|
||||
if model_name is None:
|
||||
return self.models.base_model_paths[0].name
|
||||
return model_name
|
||||
|
||||
|
||||
def clamp_prompt_logprobs(
|
||||
prompt_logprobs: Union[PromptLogprobs,
|
||||
None]) -> Union[PromptLogprobs, None]:
|
||||
if prompt_logprobs is None:
|
||||
return prompt_logprobs
|
||||
|
||||
for logprob_dict in prompt_logprobs:
|
||||
if logprob_dict is None:
|
||||
continue
|
||||
for logprob_values in logprob_dict.values():
|
||||
if logprob_values.logprob == float('-inf'):
|
||||
logprob_values.logprob = -9999.0
|
||||
return prompt_logprobs
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user