[Feature][Responses API] Support MCP tool in background mode (#23494)
Signed-off-by: wuhang <wuhang6@huawei.com>
parent b1625dbe9c
commit 6891205b16
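In brief: the Responses API previously rejected any request that combined an MCP tool server with background or streaming mode; after this change only the streaming combination is rejected, so background requests may use the MCP-backed built-in tools. To support that, the Harmony conversation contexts now receive the list of available tool names instead of pre-opened sessions, and open sessions lazily through init_tool_sessions() against an AsyncExitStack owned by whichever code path actually consumes the generator (_run_background_request or responses_stream_generator). The sketch below shows the kind of client call this enables; the base URL, model name, polling loop, and retrieval endpoint are assumptions about a locally running server, not part of this diff.

# Hypothetical client-side sketch: a background Responses API request using a
# built-in tool backed by an MCP tool server. Field names follow the OpenAI
# Responses API; the URL, model name, and status values other than "queued"
# are assumptions.
import time

import requests

BASE = "http://localhost:8000/v1"  # assumed local OpenAI-compatible vLLM server

resp = requests.post(
    f"{BASE}/responses",
    json={
        "model": "openai/gpt-oss-120b",          # assumed model name
        "input": "Look up the latest vLLM release notes.",
        "tools": [{"type": "web_search_preview"}],
        "background": True,  # previously rejected when an MCP tool server was configured
    },
).json()

# Background responses start out as "queued"; poll until the server finishes.
while resp["status"] in ("queued", "in_progress"):
    time.sleep(2)
    resp = requests.get(f"{BASE}/responses/{resp['id']}").json()

print(resp["status"], resp.get("output"))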
vllm/entrypoints/context.py

@@ -4,13 +4,15 @@ import json
 import logging
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
-from typing import TYPE_CHECKING, Union
+from contextlib import AsyncExitStack
+from typing import TYPE_CHECKING, Optional, Union
 
 from openai_harmony import Author, Message, Role, StreamState, TextContent
 
 from vllm.entrypoints.harmony_utils import (
     get_encoding, get_streamable_parser_for_assistant, render_for_completion)
 from vllm.entrypoints.tool import Tool
+from vllm.entrypoints.tool_server import ToolServer
 from vllm.outputs import RequestOutput
 
 if TYPE_CHECKING:
@@ -37,6 +39,11 @@ class ConversationContext(ABC):
     def render_for_completion(self) -> list[int]:
         pass
 
+    @abstractmethod
+    async def init_tool_sessions(self, tool_server: Optional[ToolServer],
+                                 exit_stack: AsyncExitStack) -> None:
+        pass
+
 
 class SimpleContext(ConversationContext):
 
@@ -55,16 +62,21 @@ class SimpleContext(ConversationContext):
     def render_for_completion(self) -> list[int]:
         raise NotImplementedError("Should not be called.")
 
+    async def init_tool_sessions(self, tool_server: Optional[ToolServer],
+                                 exit_stack: AsyncExitStack) -> None:
+        pass
+
 
 class HarmonyContext(ConversationContext):
 
     def __init__(
         self,
         messages: list,
-        tool_sessions: dict[str, Tool],
+        available_tools: list[str],
     ):
         self._messages = messages
-        self.tool_sessions = tool_sessions
+        self.available_tools = available_tools
+        self._tool_sessions: dict[str, Union[ClientSession, Tool]] = {}
 
         self.parser = get_streamable_parser_for_assistant()
         self.num_init_messages = len(messages)
@@ -116,10 +128,10 @@ class HarmonyContext(ConversationContext):
         if recipient is not None:
             if recipient.startswith("browser."):
                 return await self.call_search_tool(
-                    self.tool_sessions["browser"], last_msg)
+                    self._tool_sessions["browser"], last_msg)
             elif recipient.startswith("python"):
                 return await self.call_python_tool(
-                    self.tool_sessions["python"], last_msg)
+                    self._tool_sessions["python"], last_msg)
         raise ValueError("No tool call found")
 
     def render_for_completion(self) -> list[int]:
@@ -161,6 +173,15 @@ class HarmonyContext(ConversationContext):
                     recipient=Role.ASSISTANT)
         ]
 
+    async def init_tool_sessions(self, tool_server: Optional[ToolServer],
+                                 exit_stack: AsyncExitStack) -> None:
+        if tool_server:
+            for tool_name in self.available_tools:
+                if tool_name not in self._tool_sessions:
+                    self._tool_sessions[
+                        tool_name] = await exit_stack.enter_async_context(
+                            tool_server.new_session(tool_name))
+
 
 class StreamingHarmonyContext(HarmonyContext):
 
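HarmonyContext.init_tool_sessions above opens one session per entry in available_tools, skips tools whose session already exists, and registers every session on the caller's AsyncExitStack so it is closed when that stack unwinds. The following stand-alone sketch mirrors that wiring; FakeToolServer and FakeContext are illustrative stand-ins, not vLLM code.

# Minimal, self-contained sketch of the lazy session-ownership pattern.
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager
from typing import Optional


class FakeToolServer:
    """Stand-in for vLLM's ToolServer."""

    @asynccontextmanager
    async def new_session(self, tool_name: str):
        print(f"open session for {tool_name}")
        try:
            yield f"session:{tool_name}"
        finally:
            print(f"close session for {tool_name}")


class FakeContext:
    """Stand-in for HarmonyContext's session handling."""

    def __init__(self, available_tools: list[str]):
        self.available_tools = available_tools
        self._tool_sessions: dict[str, object] = {}

    async def init_tool_sessions(self, tool_server: Optional[FakeToolServer],
                                 exit_stack: AsyncExitStack) -> None:
        if tool_server:
            for tool_name in self.available_tools:
                if tool_name not in self._tool_sessions:
                    self._tool_sessions[
                        tool_name] = await exit_stack.enter_async_context(
                            tool_server.new_session(tool_name))


async def main() -> None:
    context = FakeContext(["browser", "python"])
    async with AsyncExitStack() as exit_stack:
        await context.init_tool_sessions(FakeToolServer(), exit_stack)
        # Idempotent: a second call opens nothing new.
        await context.init_tool_sessions(FakeToolServer(), exit_stack)
        print(sorted(context._tool_sessions))
    # Leaving the `async with` block closes both sessions.


asyncio.run(main())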
vllm/entrypoints/openai/serving_responses.py

@@ -8,7 +8,7 @@ from collections.abc import AsyncGenerator, AsyncIterator, Sequence
 from contextlib import AsyncExitStack
 from copy import copy
 from http import HTTPStatus
-from typing import Any, Callable, Final, Optional, Union
+from typing import Callable, Final, Optional, Union
 
 import jinja2
 import openai.types.responses as openai_responses_types
@@ -248,10 +248,10 @@ class OpenAIServingResponses(OpenAIServing):
         raw_request.state.request_metadata = request_metadata
 
         if self.tool_server is not None and isinstance(
-                self.tool_server, MCPToolServer
-        ) and (request.background or request.stream) and request.tools and any(
-                tool.type in ["web_search_preview", "code_interpreter"]
-                for tool in request.tools):
+                self.tool_server,
+                MCPToolServer) and request.stream and request.tools and any(
+                    tool.type in ["web_search_preview", "code_interpreter"]
+                    for tool in request.tools):
             return self.create_error_response(
                 "MCP tool server is not supported in background mode and "
                 "streaming mode")
@@ -265,103 +265,70 @@ class OpenAIServingResponses(OpenAIServing):
                 builtin_tool_list.append("browser")
             if self.tool_server.has_tool("python"):
                 builtin_tool_list.append("python")
-        async with AsyncExitStack() as exit_stack:
-            try:
-                if self.tool_server is not None:
-                    # TODO: initialize tool sessions lazily when the session
-                    # is actually used.
-                    tool_session_ctxs: dict[str, Any] = {
-                        tool_name:
-                        exit_stack.enter_async_context(
-                            self.tool_server.new_session(tool_name))
-                        for tool_name in builtin_tool_list
-                    }
-                    tool_sessions = {}
-                    for tool_name in builtin_tool_list:
-                        tool_sessions[tool_name] = (
-                            await tool_session_ctxs[tool_name])
-                else:
-                    assert len(builtin_tool_list) == 0
-                    tool_sessions = {}
-                for i, engine_prompt in enumerate(engine_prompts):
-                    default_max_tokens = self.max_model_len - len(
-                        engine_prompt["prompt_token_ids"])
-                    sampling_params = request.to_sampling_params(
-                        default_max_tokens, self.default_sampling_params)
-
-                    trace_headers = (None if raw_request is None else await
-                                     self._get_trace_headers(
-                                         raw_request.headers))
-
-                    context: ConversationContext
-                    if self.use_harmony:
-                        if request.stream:
-                            context = StreamingHarmonyContext(
-                                messages, tool_sessions)
-                        else:
-                            context = HarmonyContext(messages, tool_sessions)
-                    else:
-                        context = SimpleContext()
-                    generator = self._generate_with_builtin_tools(
-                        request_id=request.request_id,
-                        request_prompt=request_prompts[i],
-                        engine_prompt=engine_prompt,
-                        sampling_params=sampling_params,
-                        context=context,
-                        lora_request=lora_request,
-                        priority=request.priority,
-                        trace_headers=trace_headers,
-                    )
-                    generators.append(generator)
-            except ValueError as e:
-                # TODO: Use a vllm-specific Validation Error
-                return self.create_error_response(str(e))
-
-            assert len(generators) == 1
-            result_generator, = generators
-
-            # Store the input messages.
-            if request.store:
-                self.msg_store[request.request_id] = messages
-
-            if request.background:
-                created_time = int(time.time())
-                response = ResponsesResponse.from_request(
-                    request,
-                    sampling_params,
-                    model_name=model_name,
-                    created_time=created_time,
-                    output=[],
-                    status="queued",
-                    usage=None,
-                )
-                async with self.response_store_lock:
-                    self.response_store[response.id] = response
-
-                # Run the request in the background.
-                task = asyncio.create_task(
-                    self._run_background_request(
-                        request,
-                        sampling_params,
-                        result_generator,
-                        context,
-                        model_name,
-                        tokenizer,
-                        request_metadata,
-                        created_time,
-                    ),
-                    name=f"create_{response.id}",
-                )
-
-                # For cleanup.
-                response_id = response.id
-                self.background_tasks[response_id] = task
-                task.add_done_callback(
-                    lambda _: self.background_tasks.pop(response_id, None))
-                return response
-
-            if request.stream:
-                return self.responses_stream_generator(
-                    request,
-                    sampling_params,
-                    result_generator,
+        if self.tool_server is not None:
+            available_tools = builtin_tool_list
+        else:
+            assert len(builtin_tool_list) == 0
+            available_tools = []
+        try:
+            for i, engine_prompt in enumerate(engine_prompts):
+                default_max_tokens = self.max_model_len - len(
+                    engine_prompt["prompt_token_ids"])
+                sampling_params = request.to_sampling_params(
+                    default_max_tokens, self.default_sampling_params)
+
+                trace_headers = (None if raw_request is None else await
+                                 self._get_trace_headers(raw_request.headers))
+
+                context: ConversationContext
+                if self.use_harmony:
+                    if request.stream:
+                        context = StreamingHarmonyContext(
+                            messages, available_tools)
+                    else:
+                        context = HarmonyContext(messages, available_tools)
+                else:
+                    context = SimpleContext()
+                generator = self._generate_with_builtin_tools(
+                    request_id=request.request_id,
+                    request_prompt=request_prompts[i],
+                    engine_prompt=engine_prompt,
+                    sampling_params=sampling_params,
+                    context=context,
+                    lora_request=lora_request,
+                    priority=request.priority,
+                    trace_headers=trace_headers,
+                )
+                generators.append(generator)
+        except ValueError as e:
+            # TODO: Use a vllm-specific Validation Error
+            return self.create_error_response(str(e))
+
+        assert len(generators) == 1
+        result_generator, = generators
+
+        # Store the input messages.
+        if request.store:
+            self.msg_store[request.request_id] = messages
+
+        if request.background:
+            created_time = int(time.time())
+            response = ResponsesResponse.from_request(
+                request,
+                sampling_params,
+                model_name=model_name,
+                created_time=created_time,
+                output=[],
+                status="queued",
+                usage=None,
+            )
+            async with self.response_store_lock:
+                self.response_store[response.id] = response
+
+            # Run the request in the background.
+            task = asyncio.create_task(
+                self._run_background_request(
+                    request,
+                    sampling_params,
+                    result_generator,
@@ -369,21 +336,41 @@ class OpenAIServingResponses(OpenAIServing):
-                    model_name,
-                    tokenizer,
-                    request_metadata,
-                )
-
-            try:
-                return await self.responses_full_generator(
-                    request,
-                    sampling_params,
-                    result_generator,
-                    context,
-                    model_name,
-                    tokenizer,
-                    request_metadata,
-                )
-            except Exception as e:
-                return self.create_error_response(str(e))
-        return self.create_error_response("Should not reach here")
+                    model_name,
+                    tokenizer,
+                    request_metadata,
+                    created_time,
+                ),
+                name=f"create_{response.id}",
+            )
+
+            # For cleanup.
+            response_id = response.id
+            self.background_tasks[response_id] = task
+            task.add_done_callback(
+                lambda _: self.background_tasks.pop(response_id, None))
+            return response
+
+        if request.stream:
+            return self.responses_stream_generator(
+                request,
+                sampling_params,
+                result_generator,
+                context,
+                model_name,
+                tokenizer,
+                request_metadata,
+            )
+
+        try:
+            return await self.responses_full_generator(
+                request,
+                sampling_params,
+                result_generator,
+                context,
+                model_name,
+                tokenizer,
+                request_metadata,
+            )
+        except Exception as e:
+            return self.create_error_response(str(e))
 
     async def _make_request(
         self,
@@ -439,14 +426,16 @@ class OpenAIServingResponses(OpenAIServing):
         if created_time is None:
             created_time = int(time.time())
 
-        try:
-            async for _ in result_generator:
-                pass
-        except asyncio.CancelledError:
-            return self.create_error_response("Client disconnected")
-        except ValueError as e:
-            # TODO: Use a vllm-specific Validation Error
-            return self.create_error_response(str(e))
+        async with AsyncExitStack() as exit_stack:
+            try:
+                await context.init_tool_sessions(self.tool_server, exit_stack)
+                async for _ in result_generator:
+                    pass
+            except asyncio.CancelledError:
+                return self.create_error_response("Client disconnected")
+            except ValueError as e:
+                # TODO: Use a vllm-specific Validation Error
+                return self.create_error_response(str(e))
 
         if self.use_harmony:
             assert isinstance(context, HarmonyContext)
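_run_background_request now owns an AsyncExitStack for the whole consumption of the result generator and calls context.init_tool_sessions inside it, so tool sessions are opened only when a background request actually runs and are torn down when the loop finishes, is cancelled, or raises. A small illustrative sketch of that guarantee follows (stand-in names, not vLLM code):

# AsyncExitStack closes every session entered during the run even if the
# consuming loop fails partway through, which the background path relies on.
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager


@asynccontextmanager
async def session(name: str):
    print("opened", name)
    try:
        yield name
    finally:
        print("closed", name)


async def run_background() -> None:
    async with AsyncExitStack() as exit_stack:
        await exit_stack.enter_async_context(session("browser"))
        await exit_stack.enter_async_context(session("python"))
        raise RuntimeError("generation failed mid-run")


async def main() -> None:
    try:
        await run_background()
    except RuntimeError as e:
        print("caught:", e)  # both sessions were already closed by this point


asyncio.run(main())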
@@ -838,7 +827,7 @@ class OpenAIServingResponses(OpenAIServing):
             status_code=HTTPStatus.BAD_REQUEST,
         )
 
-    async def responses_stream_generator(
+    async def _process_streaming_events(
         self,
         request: ResponsesRequest,
         sampling_params: SamplingParams,
@@ -847,18 +836,8 @@ class OpenAIServingResponses(OpenAIServing):
         model_name: str,
         tokenizer: AnyTokenizer,
         request_metadata: RequestResponseMetadata,
-        created_time: Optional[int] = None,
+        created_time: int,
     ) -> AsyncGenerator[str, None]:
-        # TODO:
-        # 1. Handle disconnect
-
-        if not isinstance(context, StreamingHarmonyContext):
-            raise NotImplementedError(
-                "Streaming is not supported for responses API without Harmony."
-            )
-
-        created_time = created_time or int(time.time())
-
         sequence_number = 0
 
         def _send_event(event: BaseModel):
@@ -1270,3 +1249,31 @@ class OpenAIServingResponses(OpenAIServing):
                 sequence_number=-1,
                 response=final_response.model_dump(),
             ))
+
+    async def responses_stream_generator(
+        self,
+        request: ResponsesRequest,
+        sampling_params: SamplingParams,
+        result_generator: AsyncIterator[Optional[ConversationContext]],
+        context: ConversationContext,
+        model_name: str,
+        tokenizer: AnyTokenizer,
+        request_metadata: RequestResponseMetadata,
+        created_time: Optional[int] = None,
+    ) -> AsyncGenerator[str, None]:
+        # TODO:
+        # 1. Handle disconnect
+
+        if not isinstance(context, StreamingHarmonyContext):
+            raise NotImplementedError(
+                "Streaming is not supported for responses API without Harmony."
+            )
+
+        created_time = created_time or int(time.time())
+
+        async with AsyncExitStack() as exit_stack:
+            await context.init_tool_sessions(self.tool_server, exit_stack)
+            async for event_data in self._process_streaming_events(
+                    request, sampling_params, result_generator, context,
+                    model_name, tokenizer, request_metadata, created_time):
+                yield event_data
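The streaming path is now split in two: responses_stream_generator owns the AsyncExitStack and initializes tool sessions, while the renamed _process_streaming_events does the actual event emission; re-yielding from the inner generator keeps the sessions open until the client has consumed the whole stream. A minimal sketch of that wrapper pattern with stand-in names:

# Outer async generator opens resources, then re-yields events from an inner
# generator so the resources stay open for the lifetime of the stream.
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager


@asynccontextmanager
async def tool_session():
    print("session opened")
    try:
        yield
    finally:
        print("session closed")


async def process_events():        # stand-in for _process_streaming_events
    for i in range(3):
        yield f"event {i}"


async def stream_generator():      # stand-in for responses_stream_generator
    async with AsyncExitStack() as exit_stack:
        await exit_stack.enter_async_context(tool_session())
        async for event in process_events():
            yield event


async def main() -> None:
    async for event in stream_generator():
        print(event)               # the session closes only after the last event


asyncio.run(main())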