Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-27 04:08:43 +08:00)

commit 0976711f3b (parent: e261d37c9a)

[Refactor] Simplify and extract the shared tool-call parsing logic between chat completion and responses (#27961)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
vllm/entrypoints/openai/serving_chat.py

@@ -13,7 +13,6 @@ import partial_json_parser
 import regex as re
 from fastapi import Request
 from openai_harmony import Message as OpenAIMessage
-from pydantic import TypeAdapter
 
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (
@@ -47,8 +46,6 @@ from vllm.entrypoints.openai.protocol import (
     DeltaMessage,
     DeltaToolCall,
     ErrorResponse,
-    FunctionCall,
-    FunctionDefinition,
     PromptTokenUsageInfo,
     RequestResponseMetadata,
     ToolCall,
@@ -1394,6 +1391,16 @@ class OpenAIServingChat(OpenAIServing):
             auto_tools_called = False
             # if auto tools are not enabled, and a named tool choice using
             # outlines is not being used
+            tool_calls, content = self._parse_tool_calls_from_content(
+                request=request,
+                tokenizer=tokenizer,
+                content=content,
+                enable_auto_tools=self.enable_auto_tools,
+                tool_parser_cls=self.tool_parser,
+            )
+            tool_call_class = (
+                MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall
+            )
             if (not self.enable_auto_tools or not self.tool_parser) and (
                 not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
                 and request.tool_choice != "required"
@@ -1407,63 +1414,33 @@ class OpenAIServingChat(OpenAIServing):
                 request.tool_choice
                 and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam
             ):
-                tool_call_class = (
-                    MistralToolCall
-                    if isinstance(tokenizer, MistralTokenizer)
-                    else ToolCall
-                )
+                assert tool_calls is not None and len(tool_calls) > 0
                 message = ChatMessage(
                     role=role,
                     reasoning_content=reasoning_content,
                     content="",
-                    tool_calls=[
-                        tool_call_class(
-                            function=FunctionCall(
-                                name=request.tool_choice.function.name,
-                                arguments=content,
-                            )
-                        )
-                    ],
+                    tool_calls=[tool_call_class(function=tc) for tc in tool_calls],
                 )
 
             elif request.tool_choice and request.tool_choice == "required":
-                tool_call_class = (
-                    MistralToolCall
-                    if isinstance(tokenizer, MistralTokenizer)
-                    else ToolCall
-                )
-
-                # the fields of FunctionDefinition are a superset of the
-                # tool call outputs and can be used for parsing
-                assert content is not None
-                tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(
-                    content
-                )
-                tool_call_ids = []
+                tool_call_class_items = []
+                assert tool_calls is not None and len(tool_calls) > 0
                 for tool_call in tool_calls:
-                    tool_call_ids.append(
-                        make_tool_call_id(
-                            id_type=self.tool_call_id_type,
-                            func_name=tool_call.name,
-                            idx=history_tool_call_cnt,
-                        )
-                    )
+                    tool_call_class_items.append(
+                        tool_call_class(
+                            id=make_tool_call_id(
+                                id_type=self.tool_call_id_type,
+                                func_name=tool_call.name,
+                                idx=history_tool_call_cnt,
+                            ),
+                            function=tool_call,
+                        )
+                    )
                     history_tool_call_cnt += 1
                 message = ChatMessage(
                     role=role,
                     content="",
-                    tool_calls=[
-                        tool_call_class(
-                            id=tool_call_ids[i],
-                            function=FunctionCall(
-                                name=tool_call.name,
-                                arguments=json.dumps(
-                                    tool_call.parameters, ensure_ascii=False
-                                ),
-                            ),
-                        )
-                        for i, tool_call in enumerate(tool_calls)
-                    ],
+                    tool_calls=tool_call_class_items,
                     reasoning_content=reasoning_content,
                 )
 
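For context on the tool_choice == "required" branch above: guided decoding constrains the model to emit a JSON array of function definitions, which is validated with pydantic's TypeAdapter and re-serialized into per-call argument strings. Below is a minimal, hedged sketch of that parsing step; the FunctionDefinition and FunctionCall classes here are simplified stand-ins for the real models in vllm.entrypoints.openai.protocol:

import json
from pydantic import BaseModel, TypeAdapter

class FunctionDefinition(BaseModel):
    # Simplified stand-in for vllm's FunctionDefinition.
    name: str
    parameters: dict | None = None

class FunctionCall(BaseModel):
    # Simplified stand-in for vllm's FunctionCall.
    name: str
    arguments: str

# With tool_choice="required", the completion text is a JSON array of
# {name, parameters} objects produced under guided decoding.
content = '[{"name": "get_weather", "parameters": {"city": "Tokyo"}}]'

tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(content)
function_calls = [
    FunctionCall(
        name=tc.name,
        # Re-serialize the parameters object into an arguments string.
        arguments=json.dumps(tc.parameters, ensure_ascii=False),
    )
    for tc in tool_calls
]
print(function_calls[0].arguments)  # -> {"city": "Tokyo"}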
@@ -1481,25 +1458,22 @@ class OpenAIServingChat(OpenAIServing):
                 and self.enable_auto_tools
                 and self.tool_parser
             ):
-                try:
-                    tool_parser = self.tool_parser(tokenizer)
-                except RuntimeError as e:
-                    logger.exception("Error in tool parser creation.")
-                    return self.create_error_response(str(e))
-
-                tool_call_info = tool_parser.extract_tool_calls(
-                    content if content is not None else "", request=request
-                )
                 # In the OpenAI API the finish_reason is "tools_called"
                 # if the tool choice is auto and the model produced a tool
                 # call. The same is not true for named function calls
-                auto_tools_called = tool_call_info.tools_called
-                if tool_call_info.tools_called:
+                auto_tools_called = tool_calls is not None and len(tool_calls) > 0
+                if tool_calls:
                     message = ChatMessage(
                         role=role,
                         reasoning_content=reasoning_content,
-                        content=tool_call_info.content,
-                        tool_calls=tool_call_info.tool_calls,
+                        content=content,
+                        tool_calls=[
+                            ToolCall(
+                                function=tc,
+                                type="function",
+                            )
+                            for tc in tool_calls
+                        ],
                     )
 
                 else:
@@ -1509,8 +1483,8 @@ class OpenAIServingChat(OpenAIServing):
 
                 # try to use content return from tool parser first,
                 # tool parser may do some modify for the content.
-                if tool_call_info.content and len(tool_call_info.content) > 0:
-                    ret_content = tool_call_info.content
+                if content and len(content) > 0:
+                    ret_content = content
                 message = ChatMessage(
                     role=role,
                     reasoning_content=reasoning_content,
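After this refactor the auto-tools branch no longer instantiates the tool parser itself; it only branches on the (tool_calls, content) pair returned by the shared helper. A condensed, hypothetical sketch of that control flow, with plain dicts standing in for vllm's ChatMessage and ToolCall models:

def build_chat_message(tool_calls, content, role="assistant", reasoning=None):
    # tool_calls and content are the pair returned by
    # _parse_tool_calls_from_content().
    if tool_calls:
        # The model called tools: wrap each FunctionCall-like item.
        return {
            "role": role,
            "reasoning_content": reasoning,
            "content": content,
            "tool_calls": [
                {"type": "function", "function": tc} for tc in tool_calls
            ],
        }
    # No tool calls: return whatever content the helper handed back
    # (the tool parser may have rewritten it).
    return {"role": role, "reasoning_content": reasoning, "content": content}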
vllm/entrypoints/openai/serving_engine.py

@@ -12,7 +12,7 @@ from typing import Any, ClassVar, Generic, TypeAlias, TypeVar
 
 import torch
 from fastapi import Request
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
 from starlette.datastructures import Headers
 from typing_extensions import TypeIs
 
@@ -21,6 +21,10 @@ if sys.version_info >= (3, 12):
 else:
     from typing_extensions import TypedDict
 
+from openai.types.responses import (
+    ToolChoiceFunction,
+)
+
 import vllm.envs as envs
 from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.engine.protocol import EngineClient
@@ -36,6 +40,7 @@ from vllm.entrypoints.chat_utils import (
 from vllm.entrypoints.context import ConversationContext
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
+    ChatCompletionNamedToolChoiceParam,
     ChatCompletionRequest,
     ChatCompletionResponse,
     ClassificationRequest,
@@ -49,6 +54,8 @@ from vllm.entrypoints.openai.protocol import (
     EmbeddingResponse,
     ErrorInfo,
     ErrorResponse,
+    FunctionCall,
+    FunctionDefinition,
     IOProcessorRequest,
     PoolingResponse,
     RerankRequest,
@@ -1305,6 +1312,75 @@ class OpenAIServing:
         except ValueError:
             return None
 
+    @staticmethod
+    def _parse_tool_calls_from_content(
+        request: ResponsesRequest | ChatCompletionRequest,
+        tokenizer: AnyTokenizer,
+        enable_auto_tools: bool,
+        tool_parser_cls: Callable[[AnyTokenizer], ToolParser] | None,
+        content: str | None = None,
+    ) -> tuple[list[FunctionCall] | None, str | None]:
+        function_calls = list[FunctionCall]()
+        if request.tool_choice and isinstance(request.tool_choice, ToolChoiceFunction):
+            assert content is not None
+            # Forced Function Call
+            function_calls.append(
+                FunctionCall(name=request.tool_choice.name, arguments=content)
+            )
+            content = None  # Clear content since tool is called.
+        elif request.tool_choice and isinstance(
+            request.tool_choice, ChatCompletionNamedToolChoiceParam
+        ):
+            assert content is not None
+            # Forced Function Call
+            function_calls.append(
+                FunctionCall(name=request.tool_choice.function.name, arguments=content)
+            )
+            content = None  # Clear content since tool is called.
+        elif request.tool_choice == "required":
+            assert content is not None
+            tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(content)
+            function_calls.extend(
+                [
+                    FunctionCall(
+                        name=tool_call.name,
+                        arguments=json.dumps(tool_call.parameters, ensure_ascii=False),
+                    )
+                    for tool_call in tool_calls
+                ]
+            )
+            content = None  # Clear content since tool is called.
+        elif (
+            tool_parser_cls
+            and enable_auto_tools
+            and (request.tool_choice == "auto" or request.tool_choice is None)
+        ):
+            # Automatic Tool Call Parsing
+            try:
+                tool_parser = tool_parser_cls(tokenizer)
+            except RuntimeError as e:
+                logger.exception("Error in tool parser creation.")
+                raise e
+            tool_call_info = tool_parser.extract_tool_calls(
+                content if content is not None else "",
+                request=request,  # type: ignore
+            )
+            if tool_call_info is not None and tool_call_info.tools_called:
+                # extract_tool_calls() returns a list of tool calls.
+                function_calls.extend(
+                    FunctionCall(
+                        name=tool_call.function.name,
+                        arguments=tool_call.function.arguments,
+                    )
+                    for tool_call in tool_call_info.tool_calls
+                )
+                content = tool_call_info.content
+            else:
+                # No tool calls.
+                return None, content
+
+        return function_calls, content
+
     @staticmethod
     def _get_decoded_token(
         logprob: Logprob,
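To summarize the contract of the new helper: it returns (function_calls, content), where a forced or named tool choice turns the whole completion text into the arguments of a single call, "required" parses a JSON array of definitions, "auto" defers to the model-specific ToolParser, and anything else yields (None, content) untouched. The following self-contained sketch mirrors that branch order under stated simplifications (string sentinels instead of vllm's pydantic tool_choice types, and no tokenizer or parser wiring):

import json

def parse_tool_calls_sketch(tool_choice, content):
    # Simplified stand-in: a bare string that is not "auto"/"required"
    # plays the role of a named/forced tool choice.
    if isinstance(tool_choice, str) and tool_choice not in ("auto", "required"):
        # Forced function call: the entire completion text becomes the
        # arguments payload, and content is cleared.
        return [{"name": tool_choice, "arguments": content}], None
    if tool_choice == "required":
        # Guided decoding produced a JSON array of {name, parameters}.
        calls = [
            {
                "name": c["name"],
                "arguments": json.dumps(c.get("parameters"), ensure_ascii=False),
            }
            for c in json.loads(content)
        ]
        return calls, None
    # "auto"/None would consult the model-specific ToolParser here; with
    # no parser configured, report "no tool calls" and pass content through.
    return None, content

print(parse_tool_calls_sketch("get_weather", '{"city": "Tokyo"}'))
# -> ([{'name': 'get_weather', 'arguments': '{"city": "Tokyo"}'}], None)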