[Refactor] to simplify and extract the shared logic between chat completion and responses (#27961)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
Author: Chauncey
Date: 2025-11-05 15:46:39 +08:00 (committed by GitHub)
parent e261d37c9a
commit 0976711f3b
2 changed files with 112 additions and 62 deletions
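
In short: the Chat Completions and Responses code paths each re-implemented the conversion of model output into tool calls. This change moves that logic into a single OpenAIServing._parse_tool_calls_from_content() helper covering the four tool_choice modes (a Responses-style forced function via ToolChoiceFunction, a chat-style named function via ChatCompletionNamedToolChoiceParam, "required", and automatic parsing through a tool parser) and returning a (function_calls, content) pair. Below is a minimal runnable sketch of that contract, with simplified stand-ins for vLLM's types; the FunctionCall dataclass and the auto_parse callback here are illustrative, not the real signatures:

import json
from dataclasses import dataclass
from typing import Callable

@dataclass
class FunctionCall:
    name: str       # function to invoke
    arguments: str  # JSON-encoded argument string

def parse_tool_calls_from_content(
    tool_choice,  # None | "auto" | "required" | object with a .name attribute
    content: str | None,
    auto_parse: Callable[[str], tuple[list["FunctionCall"], str | None]] | None = None,
) -> tuple[list[FunctionCall] | None, str | None]:
    """Return (function_calls, remaining_content), mirroring the new helper."""
    if tool_choice is not None and not isinstance(tool_choice, str):
        # Forced/named function call: the whole completion is the argument
        # blob, so the textual content is consumed by the call.
        return [FunctionCall(tool_choice.name, content or "")], None
    if tool_choice == "required":
        # Guided decoding forced a JSON list of {name, parameters} objects.
        calls = [
            FunctionCall(t["name"], json.dumps(t.get("parameters"), ensure_ascii=False))
            for t in json.loads(content or "[]")
        ]
        return calls, None
    if auto_parse is not None and tool_choice in (None, "auto"):
        # A model-specific tool parser decides whether anything was called.
        calls, remaining = auto_parse(content or "")
        if calls:
            return calls, remaining
    return None, content  # no tool handling applies; content passes through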


@@ -13,7 +13,6 @@ import partial_json_parser
 import regex as re
 from fastapi import Request
 from openai_harmony import Message as OpenAIMessage
-from pydantic import TypeAdapter
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (
@@ -47,8 +46,6 @@ from vllm.entrypoints.openai.protocol import (
     DeltaMessage,
     DeltaToolCall,
     ErrorResponse,
-    FunctionCall,
-    FunctionDefinition,
     PromptTokenUsageInfo,
     RequestResponseMetadata,
     ToolCall,
@@ -1394,6 +1391,16 @@ class OpenAIServingChat(OpenAIServing):
         auto_tools_called = False
         # if auto tools are not enabled, and a named tool choice using
         # outlines is not being used
+        tool_calls, content = self._parse_tool_calls_from_content(
+            request=request,
+            tokenizer=tokenizer,
+            content=content,
+            enable_auto_tools=self.enable_auto_tools,
+            tool_parser_cls=self.tool_parser,
+        )
+        tool_call_class = (
+            MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall
+        )
         if (not self.enable_auto_tools or not self.tool_parser) and (
             not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
             and request.tool_choice != "required"
@@ -1407,63 +1414,33 @@ class OpenAIServingChat(OpenAIServing):
             request.tool_choice
             and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam
         ):
-            tool_call_class = (
-                MistralToolCall
-                if isinstance(tokenizer, MistralTokenizer)
-                else ToolCall
-            )
+            assert tool_calls is not None and len(tool_calls) > 0
             message = ChatMessage(
                 role=role,
                 reasoning_content=reasoning_content,
                 content="",
-                tool_calls=[
-                    tool_call_class(
-                        function=FunctionCall(
-                            name=request.tool_choice.function.name,
-                            arguments=content,
-                        )
-                    )
-                ],
+                tool_calls=[tool_call_class(function=tc) for tc in tool_calls],
             )
         elif request.tool_choice and request.tool_choice == "required":
-            tool_call_class = (
-                MistralToolCall
-                if isinstance(tokenizer, MistralTokenizer)
-                else ToolCall
-            )
-            # the fields of FunctionDefinition are a superset of the
-            # tool call outputs and can be used for parsing
-            assert content is not None
-            tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(
-                content
-            )
-            tool_call_ids = []
+            tool_call_class_items = []
+            assert tool_calls is not None and len(tool_calls) > 0
             for tool_call in tool_calls:
-                tool_call_ids.append(
-                    make_tool_call_id(
-                        id_type=self.tool_call_id_type,
-                        func_name=tool_call.name,
-                        idx=history_tool_call_cnt,
+                tool_call_class_items.append(
+                    tool_call_class(
+                        id=make_tool_call_id(
+                            id_type=self.tool_call_id_type,
+                            func_name=tool_call.name,
+                            idx=history_tool_call_cnt,
+                        ),
+                        function=tool_call,
                     )
                 )
                 history_tool_call_cnt += 1
             message = ChatMessage(
                 role=role,
                 content="",
-                tool_calls=[
-                    tool_call_class(
-                        id=tool_call_ids[i],
-                        function=FunctionCall(
-                            name=tool_call.name,
-                            arguments=json.dumps(
-                                tool_call.parameters, ensure_ascii=False
-                            ),
-                        ),
-                    )
-                    for i, tool_call in enumerate(tool_calls)
-                ],
+                tool_calls=tool_call_class_items,
                 reasoning_content=reasoning_content,
             )
@@ -1481,25 +1458,22 @@ class OpenAIServingChat(OpenAIServing):
             and self.enable_auto_tools
             and self.tool_parser
         ):
-            try:
-                tool_parser = self.tool_parser(tokenizer)
-            except RuntimeError as e:
-                logger.exception("Error in tool parser creation.")
-                return self.create_error_response(str(e))
-            tool_call_info = tool_parser.extract_tool_calls(
-                content if content is not None else "", request=request
-            )
             # In the OpenAI API the finish_reason is "tools_called"
             # if the tool choice is auto and the model produced a tool
             # call. The same is not true for named function calls
-            auto_tools_called = tool_call_info.tools_called
-            if tool_call_info.tools_called:
+            auto_tools_called = tool_calls is not None and len(tool_calls) > 0
+            if tool_calls:
                 message = ChatMessage(
                     role=role,
                     reasoning_content=reasoning_content,
-                    content=tool_call_info.content,
-                    tool_calls=tool_call_info.tool_calls,
+                    content=content,
+                    tool_calls=[
+                        ToolCall(
+                            function=tc,
+                            type="function",
+                        )
+                        for tc in tool_calls
+                    ],
                 )
             else:
@@ -1509,8 +1483,8 @@ class OpenAIServingChat(OpenAIServing):
             # try to use content return from tool parser first,
             # tool parser may do some modify for the content.
-            if tool_call_info.content and len(tool_call_info.content) > 0:
-                ret_content = tool_call_info.content
+            if content and len(content) > 0:
+                ret_content = content
             message = ChatMessage(
                 role=role,
                 reasoning_content=reasoning_content,


@@ -12,7 +12,7 @@ from typing import Any, ClassVar, Generic, TypeAlias, TypeVar
 import torch
 from fastapi import Request
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
 from starlette.datastructures import Headers
 from typing_extensions import TypeIs
@@ -21,6 +21,10 @@ if sys.version_info >= (3, 12):
 else:
     from typing_extensions import TypedDict
 
+from openai.types.responses import (
+    ToolChoiceFunction,
+)
+
 import vllm.envs as envs
 from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.engine.protocol import EngineClient
@@ -36,6 +40,7 @@ from vllm.entrypoints.chat_utils import (
 from vllm.entrypoints.context import ConversationContext
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
+    ChatCompletionNamedToolChoiceParam,
     ChatCompletionRequest,
     ChatCompletionResponse,
     ClassificationRequest,
@@ -49,6 +54,8 @@ from vllm.entrypoints.openai.protocol import (
     EmbeddingResponse,
     ErrorInfo,
     ErrorResponse,
+    FunctionCall,
+    FunctionDefinition,
     IOProcessorRequest,
     PoolingResponse,
     RerankRequest,
@@ -1305,6 +1312,75 @@ class OpenAIServing:
         except ValueError:
             return None
 
+    @staticmethod
+    def _parse_tool_calls_from_content(
+        request: ResponsesRequest | ChatCompletionRequest,
+        tokenizer: AnyTokenizer,
+        enable_auto_tools: bool,
+        tool_parser_cls: Callable[[AnyTokenizer], ToolParser] | None,
+        content: str | None = None,
+    ) -> tuple[list[FunctionCall] | None, str | None]:
+        function_calls = list[FunctionCall]()
+        if request.tool_choice and isinstance(request.tool_choice, ToolChoiceFunction):
+            assert content is not None
+            # Forced Function Call
+            function_calls.append(
+                FunctionCall(name=request.tool_choice.name, arguments=content)
+            )
+            content = None  # Clear content since tool is called.
+        elif request.tool_choice and isinstance(
+            request.tool_choice, ChatCompletionNamedToolChoiceParam
+        ):
+            assert content is not None
+            # Forced Function Call
+            function_calls.append(
+                FunctionCall(name=request.tool_choice.function.name, arguments=content)
+            )
+            content = None  # Clear content since tool is called.
+        elif request.tool_choice == "required":
+            assert content is not None
+            tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(content)
+            function_calls.extend(
+                [
+                    FunctionCall(
+                        name=tool_call.name,
+                        arguments=json.dumps(tool_call.parameters, ensure_ascii=False),
+                    )
+                    for tool_call in tool_calls
+                ]
+            )
+            content = None  # Clear content since tool is called.
+        elif (
+            tool_parser_cls
+            and enable_auto_tools
+            and (request.tool_choice == "auto" or request.tool_choice is None)
+        ):
+            # Automatic Tool Call Parsing
+            try:
+                tool_parser = tool_parser_cls(tokenizer)
+            except RuntimeError as e:
+                logger.exception("Error in tool parser creation.")
+                raise e
+            tool_call_info = tool_parser.extract_tool_calls(
+                content if content is not None else "",
+                request=request,  # type: ignore
+            )
+            if tool_call_info is not None and tool_call_info.tools_called:
+                # extract_tool_calls() returns a list of tool calls.
+                function_calls.extend(
+                    FunctionCall(
+                        name=tool_call.function.name,
+                        arguments=tool_call.function.arguments,
+                    )
+                    for tool_call in tool_call_info.tool_calls
+                )
+                content = tool_call_info.content
+        else:
+            # No tool calls.
+            return None, content
+        return function_calls, content
+
     @staticmethod
     def _get_decoded_token(
         logprob: Logprob,
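
To make the "required" branch of the new helper concrete: under tool_choice="required" the model is constrained to emit a JSON list of {name, parameters} objects, which the helper re-validates with a pydantic TypeAdapter and flattens into FunctionCall entries whose arguments field is re-serialized JSON. A standalone approximation follows; the FunctionDefinition model below is a simplified stand-in (vLLM's real class carries more fields, but name/parameters are what this path reads):

import json
from pydantic import BaseModel, TypeAdapter

# Simplified stand-in for vLLM's FunctionDefinition.
class FunctionDefinition(BaseModel):
    name: str
    parameters: dict | None = None

content = '[{"name": "get_weather", "parameters": {"city": "Paris"}}]'
tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(content)
function_calls = [
    {"name": tc.name, "arguments": json.dumps(tc.parameters, ensure_ascii=False)}
    for tc in tool_calls
]
print(function_calls)
# [{'name': 'get_weather', 'arguments': '{"city": "Paris"}'}]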