[Feature][Responses API] Support MCP tool in background mode (#23494)

Signed-off-by: wuhang <wuhang6@huawei.com>
wuhang 2025-08-27 09:06:58 +08:00 committed by GitHub
parent b1625dbe9c
commit 6891205b16
2 changed files with 162 additions and 134 deletions
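With this change, requests that combine `background=True` with the built-in MCP tools (`web_search_preview`, `code_interpreter`) are accepted; only the streaming path still rejects the MCP tool server. A minimal client-side sketch of the newly supported flow, assuming a running vLLM OpenAI-compatible server with a tool server configured (the base URL, model name, and polling loop are placeholders, not part of this commit):

```python
# Sketch: submit a background Responses API request that uses a built-in tool,
# then poll for the result. Endpoint and model name are illustrative.
import time

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

resp = client.responses.create(
    model="openai/gpt-oss-20b",            # placeholder model name
    input="What changed in Python 3.13?",
    tools=[{"type": "web_search_preview"}],
    background=True,                        # returns immediately, status "queued"
)

# Tool sessions are now opened lazily on the server, only when the background
# request actually runs, so this request is no longer rejected.
while resp.status in ("queued", "in_progress"):
    time.sleep(1)
    resp = client.responses.retrieve(resp.id)

print(resp.status)
```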

View File

@@ -4,13 +4,15 @@ import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import TYPE_CHECKING, Union
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Optional, Union
from openai_harmony import Author, Message, Role, StreamState, TextContent
from vllm.entrypoints.harmony_utils import (
get_encoding, get_streamable_parser_for_assistant, render_for_completion)
from vllm.entrypoints.tool import Tool
from vllm.entrypoints.tool_server import ToolServer
from vllm.outputs import RequestOutput
if TYPE_CHECKING:
@@ -37,6 +39,11 @@ class ConversationContext(ABC):
def render_for_completion(self) -> list[int]:
pass
@abstractmethod
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
exit_stack: AsyncExitStack) -> None:
pass
class SimpleContext(ConversationContext):
@@ -55,16 +62,21 @@ class SimpleContext(ConversationContext):
def render_for_completion(self) -> list[int]:
raise NotImplementedError("Should not be called.")
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
exit_stack: AsyncExitStack) -> None:
pass
class HarmonyContext(ConversationContext):
def __init__(
self,
messages: list,
tool_sessions: dict[str, Tool],
available_tools: list[str],
):
self._messages = messages
self.tool_sessions = tool_sessions
self.available_tools = available_tools
self._tool_sessions: dict[str, Union[ClientSession, Tool]] = {}
self.parser = get_streamable_parser_for_assistant()
self.num_init_messages = len(messages)
@@ -116,10 +128,10 @@ class HarmonyContext(ConversationContext):
if recipient is not None:
if recipient.startswith("browser."):
return await self.call_search_tool(
self.tool_sessions["browser"], last_msg)
self._tool_sessions["browser"], last_msg)
elif recipient.startswith("python"):
return await self.call_python_tool(
self.tool_sessions["python"], last_msg)
self._tool_sessions["python"], last_msg)
raise ValueError("No tool call found")
def render_for_completion(self) -> list[int]:
@@ -161,6 +173,15 @@ class HarmonyContext(ConversationContext):
recipient=Role.ASSISTANT)
]
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
exit_stack: AsyncExitStack) -> None:
if tool_server:
for tool_name in self.available_tools:
if tool_name not in self._tool_sessions:
self._tool_sessions[
tool_name] = await exit_stack.enter_async_context(
tool_server.new_session(tool_name))
class StreamingHarmonyContext(HarmonyContext):
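The key change in this file is that contexts no longer receive ready-made tool sessions; they receive the list of available tool names and open sessions lazily through `init_tool_sessions`, entering each session on an `AsyncExitStack` owned by the caller. A self-contained sketch of that lifecycle (the `fake_session` context manager and tool names are stand-ins, not the real `ToolServer` API):

```python
# Sketch of the lazy-session pattern used by init_tool_sessions: sessions are
# opened only when a request actually runs, and the AsyncExitStack guarantees
# they are closed when the caller's scope exits, even on cancellation.
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager


@asynccontextmanager
async def fake_session(tool_name: str):
    print(f"open {tool_name}")
    try:
        yield f"session:{tool_name}"
    finally:
        print(f"close {tool_name}")


async def run_request(available_tools: list[str]) -> None:
    tool_sessions: dict[str, str] = {}
    async with AsyncExitStack() as exit_stack:
        for name in available_tools:
            if name not in tool_sessions:  # mirrors the dedup check above
                tool_sessions[name] = await exit_stack.enter_async_context(
                    fake_session(name))
        # ... generate and call tools here ...
    # All sessions are closed once the block exits.


asyncio.run(run_request(["browser", "python"]))
```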

View File

@@ -8,7 +8,7 @@ from collections.abc import AsyncGenerator, AsyncIterator, Sequence
from contextlib import AsyncExitStack
from copy import copy
from http import HTTPStatus
from typing import Any, Callable, Final, Optional, Union
from typing import Callable, Final, Optional, Union
import jinja2
import openai.types.responses as openai_responses_types
@@ -248,10 +248,10 @@ class OpenAIServingResponses(OpenAIServing):
raw_request.state.request_metadata = request_metadata
if self.tool_server is not None and isinstance(
self.tool_server, MCPToolServer
) and (request.background or request.stream) and request.tools and any(
tool.type in ["web_search_preview", "code_interpreter"]
for tool in request.tools):
self.tool_server,
MCPToolServer) and request.stream and request.tools and any(
tool.type in ["web_search_preview", "code_interpreter"]
for tool in request.tools):
return self.create_error_response(
"MCP tool server is not supported in background mode and "
"streaming mode")
@@ -265,103 +265,70 @@ class OpenAIServingResponses(OpenAIServing):
builtin_tool_list.append("browser")
if self.tool_server.has_tool("python"):
builtin_tool_list.append("python")
async with AsyncExitStack() as exit_stack:
try:
if self.tool_server is not None:
# TODO: initialize tool sessions lazily when the session
# is actually used.
tool_session_ctxs: dict[str, Any] = {
tool_name:
exit_stack.enter_async_context(
self.tool_server.new_session(tool_name))
for tool_name in builtin_tool_list
}
tool_sessions = {}
for tool_name in builtin_tool_list:
tool_sessions[tool_name] = (
await tool_session_ctxs[tool_name])
else:
assert len(builtin_tool_list) == 0
tool_sessions = {}
for i, engine_prompt in enumerate(engine_prompts):
default_max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"])
sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params)
trace_headers = (None if raw_request is None else await
self._get_trace_headers(
raw_request.headers))
if self.tool_server is not None:
available_tools = builtin_tool_list
else:
assert len(builtin_tool_list) == 0
available_tools = []
try:
for i, engine_prompt in enumerate(engine_prompts):
default_max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"])
sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params)
context: ConversationContext
if self.use_harmony:
if request.stream:
context = StreamingHarmonyContext(
messages, tool_sessions)
else:
context = HarmonyContext(messages, tool_sessions)
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))
context: ConversationContext
if self.use_harmony:
if request.stream:
context = StreamingHarmonyContext(
messages, available_tools)
else:
context = SimpleContext()
generator = self._generate_with_builtin_tools(
request_id=request.request_id,
request_prompt=request_prompts[i],
engine_prompt=engine_prompt,
sampling_params=sampling_params,
context=context,
lora_request=lora_request,
priority=request.priority,
trace_headers=trace_headers,
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
assert len(generators) == 1
result_generator, = generators
# Store the input messages.
if request.store:
self.msg_store[request.request_id] = messages
if request.background:
created_time = int(time.time())
response = ResponsesResponse.from_request(
request,
sampling_params,
model_name=model_name,
created_time=created_time,
output=[],
status="queued",
usage=None,
context = HarmonyContext(messages, available_tools)
else:
context = SimpleContext()
generator = self._generate_with_builtin_tools(
request_id=request.request_id,
request_prompt=request_prompts[i],
engine_prompt=engine_prompt,
sampling_params=sampling_params,
context=context,
lora_request=lora_request,
priority=request.priority,
trace_headers=trace_headers,
)
async with self.response_store_lock:
self.response_store[response.id] = response
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
# Run the request in the background.
task = asyncio.create_task(
self._run_background_request(
request,
sampling_params,
result_generator,
context,
model_name,
tokenizer,
request_metadata,
created_time,
),
name=f"create_{response.id}",
)
assert len(generators) == 1
result_generator, = generators
# For cleanup.
response_id = response.id
self.background_tasks[response_id] = task
task.add_done_callback(
lambda _: self.background_tasks.pop(response_id, None))
return response
# Store the input messages.
if request.store:
self.msg_store[request.request_id] = messages
if request.stream:
return self.responses_stream_generator(
if request.background:
created_time = int(time.time())
response = ResponsesResponse.from_request(
request,
sampling_params,
model_name=model_name,
created_time=created_time,
output=[],
status="queued",
usage=None,
)
async with self.response_store_lock:
self.response_store[response.id] = response
# Run the request in the background.
task = asyncio.create_task(
self._run_background_request(
request,
sampling_params,
result_generator,
@@ -369,21 +336,41 @@ class OpenAIServingResponses(OpenAIServing):
model_name,
tokenizer,
request_metadata,
)
created_time,
),
name=f"create_{response.id}",
)
try:
return await self.responses_full_generator(
request,
sampling_params,
result_generator,
context,
model_name,
tokenizer,
request_metadata,
)
except Exception as e:
return self.create_error_response(str(e))
return self.create_error_response("Should not reach here")
# For cleanup.
response_id = response.id
self.background_tasks[response_id] = task
task.add_done_callback(
lambda _: self.background_tasks.pop(response_id, None))
return response
if request.stream:
return self.responses_stream_generator(
request,
sampling_params,
result_generator,
context,
model_name,
tokenizer,
request_metadata,
)
try:
return await self.responses_full_generator(
request,
sampling_params,
result_generator,
context,
model_name,
tokenizer,
request_metadata,
)
except Exception as e:
return self.create_error_response(str(e))
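For `background=True`, the handler now stores a queued `ResponsesResponse`, launches `_run_background_request` as an asyncio task, and registers a done-callback so the task handle is dropped from `self.background_tasks` when it finishes. A compact sketch of that task-registry pattern (the `run_request` coroutine and response id are illustrative):

```python
# Sketch of the background-task bookkeeping: create the task, keep a handle so
# it is not garbage-collected, and remove the handle once the task completes.
import asyncio

background_tasks: dict[str, asyncio.Task] = {}


async def run_request(response_id: str) -> None:
    await asyncio.sleep(0.1)  # stand-in for draining the result generator
    print(f"{response_id} finished")


def launch(response_id: str) -> None:
    task = asyncio.create_task(run_request(response_id),
                               name=f"create_{response_id}")
    background_tasks[response_id] = task
    task.add_done_callback(lambda _: background_tasks.pop(response_id, None))


async def main() -> None:
    launch("resp_123")
    await asyncio.gather(*background_tasks.values())


asyncio.run(main())
```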
async def _make_request(
self,
@@ -439,14 +426,16 @@ class OpenAIServingResponses(OpenAIServing):
if created_time is None:
created_time = int(time.time())
try:
async for _ in result_generator:
pass
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
async with AsyncExitStack() as exit_stack:
try:
await context.init_tool_sessions(self.tool_server, exit_stack)
async for _ in result_generator:
pass
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
if self.use_harmony:
assert isinstance(context, HarmonyContext)
@@ -838,7 +827,7 @@ class OpenAIServingResponses(OpenAIServing):
status_code=HTTPStatus.BAD_REQUEST,
)
async def responses_stream_generator(
async def _process_streaming_events(
self,
request: ResponsesRequest,
sampling_params: SamplingParams,
@@ -847,18 +836,8 @@ class OpenAIServingResponses(OpenAIServing):
model_name: str,
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
created_time: Optional[int] = None,
created_time: int,
) -> AsyncGenerator[str, None]:
# TODO:
# 1. Handle disconnect
if not isinstance(context, StreamingHarmonyContext):
raise NotImplementedError(
"Streaming is not supported for responses API without Harmony."
)
created_time = created_time or int(time.time())
sequence_number = 0
def _send_event(event: BaseModel):
@@ -1270,3 +1249,31 @@ class OpenAIServingResponses(OpenAIServing):
sequence_number=-1,
response=final_response.model_dump(),
))
async def responses_stream_generator(
self,
request: ResponsesRequest,
sampling_params: SamplingParams,
result_generator: AsyncIterator[Optional[ConversationContext]],
context: ConversationContext,
model_name: str,
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
created_time: Optional[int] = None,
) -> AsyncGenerator[str, None]:
# TODO:
# 1. Handle disconnect
if not isinstance(context, StreamingHarmonyContext):
raise NotImplementedError(
"Streaming is not supported for responses API without Harmony."
)
created_time = created_time or int(time.time())
async with AsyncExitStack() as exit_stack:
await context.init_tool_sessions(self.tool_server, exit_stack)
async for event_data in self._process_streaming_events(
request, sampling_params, result_generator, context,
model_name, tokenizer, request_metadata, created_time):
yield event_data
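
In the streaming path, `responses_stream_generator` becomes a thin wrapper that owns the `AsyncExitStack` while delegating event production to `_process_streaming_events`, so tool sessions stay open exactly as long as the event stream is being consumed. A minimal sketch of that resource-owning async-generator pattern (all names here are illustrative):

```python
# Sketch of the wrapper pattern: the outer async generator holds the exit
# stack open while events are yielded and closes it when iteration finishes.
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager


@asynccontextmanager
async def tool_session():
    print("session open")
    try:
        yield
    finally:
        print("session closed")


async def _process_events():
    for i in range(3):
        yield f"event {i}"


async def stream_events():
    async with AsyncExitStack() as exit_stack:
        await exit_stack.enter_async_context(tool_session())
        async for event in _process_events():
            yield event


async def main() -> None:
    async for event in stream_events():
        print(event)


asyncio.run(main())
```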