[Frontend] Do prompt_logprobs clamping for chat as well as completions (#14225)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
commit e5b2f1601a (parent 9badee53de)
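This commit moves the -inf clamping that the completions endpoint already performed inline into a shared clamp_prompt_logprobs helper in serving_engine.py and applies it to chat completion responses as well, so prompt logprobs of float('-inf') are returned as -9999.0 on both endpoints. A plausible motivation (my reading, not stated in the commit message) is JSON serialization: -inf has no standard JSON representation, as this small standalone sketch shows.

import json

# The stdlib encoder emits the non-standard token -Infinity by default...
print(json.dumps({"logprob": float("-inf")}))  # {"logprob": -Infinity}

# ...and raises when strict, standards-compliant JSON is requested.
try:
    json.dumps({"logprob": float("-inf")}, allow_nan=False)
except ValueError as err:
    print(err)  # Out of range float values are not JSON compliant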
vllm/entrypoints/openai/serving_chat.py
@@ -24,7 +24,8 @@ from vllm.entrypoints.openai.protocol import (
     RequestResponseMetadata, ToolCall, UsageInfo)
 from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
                                                        ReasoningParserManager)
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
+                                                    clamp_prompt_logprobs)
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
 from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
@@ -844,7 +845,7 @@ class OpenAIServingChat(OpenAIServing):
             model=model_name,
             choices=choices,
             usage=usage,
-            prompt_logprobs=final_res.prompt_logprobs,
+            prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs),
         )
 
         return response
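As a usage sketch for the chat-side change above: how a client could request prompt logprobs from a running vLLM OpenAI-compatible server and observe the clamped values. The prompt_logprobs request field and the model name are assumptions on my part, not something this diff shows.

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "my-model",                               # placeholder model name
        "messages": [{"role": "user", "content": "Hi!"}],
        "prompt_logprobs": 1,                              # assumed vLLM extension field
    },
    timeout=60,
)
# After this change, any -inf entries in the returned prompt_logprobs arrive
# clamped to -9999.0, matching the existing completions behaviour.
print(resp.json().get("prompt_logprobs"))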
vllm/entrypoints/openai/serving_completion.py
@@ -23,7 +23,8 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                               RequestResponseMetadata,
                                               UsageInfo)
 # yapf: enable
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
+                                                    clamp_prompt_logprobs)
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
@@ -394,13 +395,7 @@ class OpenAIServingCompletion(OpenAIServing):
         for final_res in final_res_batch:
             prompt_token_ids = final_res.prompt_token_ids
             assert prompt_token_ids is not None
-            prompt_logprobs = final_res.prompt_logprobs
-            if prompt_logprobs:
-                for logprob_dict in prompt_logprobs:
-                    if logprob_dict:
-                        for logprob_values in logprob_dict.values():
-                            if logprob_values.logprob == float('-inf'):
-                                logprob_values.logprob = -9999.0
+            prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
             prompt_text = final_res.prompt
 
             token_ids: GenericSequence[int]
vllm/entrypoints/openai/serving_engine.py
@@ -42,7 +42,7 @@ from vllm.lora.request import LoRARequest
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import BeamSearchParams, SamplingParams
-from vllm.sequence import Logprob
+from vllm.sequence import Logprob, PromptLogprobs
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                           log_tracing_disabled_warning)
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
@@ -535,3 +535,18 @@ class OpenAIServing:
         if model_name is None:
             return self.models.base_model_paths[0].name
         return model_name
+
+
+def clamp_prompt_logprobs(
+    prompt_logprobs: Union[PromptLogprobs,
+                           None]) -> Union[PromptLogprobs, None]:
+    if prompt_logprobs is None:
+        return prompt_logprobs
+
+    for logprob_dict in prompt_logprobs:
+        if logprob_dict is None:
+            continue
+        for logprob_values in logprob_dict.values():
+            if logprob_values.logprob == float('-inf'):
+                logprob_values.logprob = -9999.0
+    return prompt_logprobs
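To try the new helper in isolation, here is a self-contained sketch that uses a stand-in Logprob type instead of vllm.sequence.Logprob (assumption: only the logprob attribute matters for clamping, as in the function added above); the PromptLogprobs alias below is likewise an approximation of the real type.

from dataclasses import dataclass
from typing import Dict, List, Optional, Union


@dataclass
class Logprob:          # stand-in for vllm.sequence.Logprob
    logprob: float


# Approximation of the real PromptLogprobs: one optional dict per prompt token.
PromptLogprobs = List[Optional[Dict[int, Logprob]]]


def clamp_prompt_logprobs(
        prompt_logprobs: Union[PromptLogprobs,
                               None]) -> Union[PromptLogprobs, None]:
    # Same logic as the helper added above: replace -inf in place with -9999.0.
    if prompt_logprobs is None:
        return prompt_logprobs

    for logprob_dict in prompt_logprobs:
        if logprob_dict is None:
            continue
        for logprob_values in logprob_dict.values():
            if logprob_values.logprob == float('-inf'):
                logprob_values.logprob = -9999.0
    return prompt_logprobs


# The first prompt position typically has no logprob, hence the leading None.
example = [None, {42: Logprob(float('-inf'))}, {7: Logprob(-0.25)}]
clamp_prompt_logprobs(example)
print(example)  # [None, {42: Logprob(logprob=-9999.0)}, {7: Logprob(logprob=-0.25)}]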