From 6891205b161e78ea6e255da194a4470e06997a3b Mon Sep 17 00:00:00 2001
From: wuhang
Date: Wed, 27 Aug 2025 09:06:58 +0800
Subject: [PATCH] [Feature][Responses API] Support MCP tool in background mode
 (#23494)

Signed-off-by: wuhang
---
 vllm/entrypoints/context.py                  |  31 ++-
 vllm/entrypoints/openai/serving_responses.py | 265 ++++++++++---------
 2 files changed, 162 insertions(+), 134 deletions(-)

diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py
index f70e1fc207f86..9d587e8669339 100644
--- a/vllm/entrypoints/context.py
+++ b/vllm/entrypoints/context.py
@@ -4,13 +4,15 @@ import json
 import logging
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
-from typing import TYPE_CHECKING, Union
+from contextlib import AsyncExitStack
+from typing import TYPE_CHECKING, Optional, Union
 
 from openai_harmony import Author, Message, Role, StreamState, TextContent
 
 from vllm.entrypoints.harmony_utils import (
     get_encoding, get_streamable_parser_for_assistant, render_for_completion)
 from vllm.entrypoints.tool import Tool
+from vllm.entrypoints.tool_server import ToolServer
 from vllm.outputs import RequestOutput
 
 if TYPE_CHECKING:
@@ -37,6 +39,11 @@ class ConversationContext(ABC):
     def render_for_completion(self) -> list[int]:
         pass
 
+    @abstractmethod
+    async def init_tool_sessions(self, tool_server: Optional[ToolServer],
+                                 exit_stack: AsyncExitStack) -> None:
+        pass
+
 
 class SimpleContext(ConversationContext):
 
@@ -55,16 +62,21 @@ class SimpleContext(ConversationContext):
     def render_for_completion(self) -> list[int]:
         raise NotImplementedError("Should not be called.")
 
+    async def init_tool_sessions(self, tool_server: Optional[ToolServer],
+                                 exit_stack: AsyncExitStack) -> None:
+        pass
+
 
 class HarmonyContext(ConversationContext):
 
     def __init__(
         self,
         messages: list,
-        tool_sessions: dict[str, Tool],
+        available_tools: list[str],
     ):
         self._messages = messages
-        self.tool_sessions = tool_sessions
+        self.available_tools = available_tools
+        self._tool_sessions: dict[str, Union[ClientSession, Tool]] = {}
 
         self.parser = get_streamable_parser_for_assistant()
         self.num_init_messages = len(messages)
@@ -116,10 +128,10 @@ class HarmonyContext(ConversationContext):
         if recipient is not None:
             if recipient.startswith("browser."):
                 return await self.call_search_tool(
-                    self.tool_sessions["browser"], last_msg)
+                    self._tool_sessions["browser"], last_msg)
             elif recipient.startswith("python"):
                 return await self.call_python_tool(
-                    self.tool_sessions["python"], last_msg)
+                    self._tool_sessions["python"], last_msg)
         raise ValueError("No tool call found")
 
     def render_for_completion(self) -> list[int]:
@@ -161,6 +173,15 @@ class HarmonyContext(ConversationContext):
                 recipient=Role.ASSISTANT)
         ]
 
+    async def init_tool_sessions(self, tool_server: Optional[ToolServer],
+                                 exit_stack: AsyncExitStack) -> None:
+        if tool_server:
+            for tool_name in self.available_tools:
+                if tool_name not in self._tool_sessions:
+                    self._tool_sessions[
+                        tool_name] = await exit_stack.enter_async_context(
+                            tool_server.new_session(tool_name))
+
 
 class StreamingHarmonyContext(HarmonyContext):
 
diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py
index 5adcb310e3468..67eec2d523e3f 100644
--- a/vllm/entrypoints/openai/serving_responses.py
+++ b/vllm/entrypoints/openai/serving_responses.py
@@ -8,7 +8,7 @@ from collections.abc import AsyncGenerator, AsyncIterator, Sequence
 from contextlib import AsyncExitStack
 from copy import copy
 from http import HTTPStatus
-from typing import Any, Callable, Final, Optional, Union
+from typing import Callable, Final, Optional, Union
 
 import jinja2
 import openai.types.responses as openai_responses_types
@@ -248,10 +248,10 @@ class OpenAIServingResponses(OpenAIServing):
             raw_request.state.request_metadata = request_metadata
 
         if self.tool_server is not None and isinstance(
-                self.tool_server, MCPToolServer
-        ) and (request.background or request.stream) and request.tools and any(
-                tool.type in ["web_search_preview", "code_interpreter"]
-                for tool in request.tools):
+                self.tool_server,
+                MCPToolServer) and request.stream and request.tools and any(
+                    tool.type in ["web_search_preview", "code_interpreter"]
+                    for tool in request.tools):
             return self.create_error_response(
                 "MCP tool server is not supported in background mode and "
                 "streaming mode")
@@ -265,103 +265,70 @@ class OpenAIServingResponses(OpenAIServing):
                 builtin_tool_list.append("browser")
             if self.tool_server.has_tool("python"):
                 builtin_tool_list.append("python")
 
-        async with AsyncExitStack() as exit_stack:
-            try:
-                if self.tool_server is not None:
-                    # TODO: initialize tool sessions lazily when the session
-                    # is actually used.
-                    tool_session_ctxs: dict[str, Any] = {
-                        tool_name:
-                        exit_stack.enter_async_context(
-                            self.tool_server.new_session(tool_name))
-                        for tool_name in builtin_tool_list
-                    }
-                    tool_sessions = {}
-                    for tool_name in builtin_tool_list:
-                        tool_sessions[tool_name] = (
-                            await tool_session_ctxs[tool_name])
-                else:
-                    assert len(builtin_tool_list) == 0
-                    tool_sessions = {}
-                for i, engine_prompt in enumerate(engine_prompts):
-                    default_max_tokens = self.max_model_len - len(
-                        engine_prompt["prompt_token_ids"])
-                    sampling_params = request.to_sampling_params(
-                        default_max_tokens, self.default_sampling_params)
-                    trace_headers = (None if raw_request is None else await
-                                     self._get_trace_headers(
-                                         raw_request.headers))
+        if self.tool_server is not None:
+            available_tools = builtin_tool_list
+        else:
+            assert len(builtin_tool_list) == 0
+            available_tools = []
+        try:
+            for i, engine_prompt in enumerate(engine_prompts):
+                default_max_tokens = self.max_model_len - len(
+                    engine_prompt["prompt_token_ids"])
+                sampling_params = request.to_sampling_params(
+                    default_max_tokens, self.default_sampling_params)
 
-                    context: ConversationContext
-                    if self.use_harmony:
-                        if request.stream:
-                            context = StreamingHarmonyContext(
-                                messages, tool_sessions)
-                        else:
-                            context = HarmonyContext(messages, tool_sessions)
+                trace_headers = (None if raw_request is None else await
+                                 self._get_trace_headers(raw_request.headers))
+
+                context: ConversationContext
+                if self.use_harmony:
+                    if request.stream:
+                        context = StreamingHarmonyContext(
+                            messages, available_tools)
                     else:
-                        context = SimpleContext()
-                    generator = self._generate_with_builtin_tools(
-                        request_id=request.request_id,
-                        request_prompt=request_prompts[i],
-                        engine_prompt=engine_prompt,
-                        sampling_params=sampling_params,
-                        context=context,
-                        lora_request=lora_request,
-                        priority=request.priority,
-                        trace_headers=trace_headers,
-                    )
-                    generators.append(generator)
-            except ValueError as e:
-                # TODO: Use a vllm-specific Validation Error
-                return self.create_error_response(str(e))
-
-            assert len(generators) == 1
-            result_generator, = generators
-
-            # Store the input messages.
-            if request.store:
-                self.msg_store[request.request_id] = messages
-
-            if request.background:
-                created_time = int(time.time())
-                response = ResponsesResponse.from_request(
-                    request,
-                    sampling_params,
-                    model_name=model_name,
-                    created_time=created_time,
-                    output=[],
-                    status="queued",
-                    usage=None,
+                        context = HarmonyContext(messages, available_tools)
+                else:
+                    context = SimpleContext()
+                generator = self._generate_with_builtin_tools(
+                    request_id=request.request_id,
+                    request_prompt=request_prompts[i],
+                    engine_prompt=engine_prompt,
+                    sampling_params=sampling_params,
+                    context=context,
+                    lora_request=lora_request,
+                    priority=request.priority,
+                    trace_headers=trace_headers,
                 )
-                async with self.response_store_lock:
-                    self.response_store[response.id] = response
+                generators.append(generator)
+        except ValueError as e:
+            # TODO: Use a vllm-specific Validation Error
+            return self.create_error_response(str(e))
 
-                # Run the request in the background.
-                task = asyncio.create_task(
-                    self._run_background_request(
-                        request,
-                        sampling_params,
-                        result_generator,
-                        context,
-                        model_name,
-                        tokenizer,
-                        request_metadata,
-                        created_time,
-                    ),
-                    name=f"create_{response.id}",
-                )
+        assert len(generators) == 1
+        result_generator, = generators
 
-                # For cleanup.
-                response_id = response.id
-                self.background_tasks[response_id] = task
-                task.add_done_callback(
-                    lambda _: self.background_tasks.pop(response_id, None))
-                return response
+        # Store the input messages.
+        if request.store:
+            self.msg_store[request.request_id] = messages
 
-            if request.stream:
-                return self.responses_stream_generator(
+        if request.background:
+            created_time = int(time.time())
+            response = ResponsesResponse.from_request(
+                request,
+                sampling_params,
+                model_name=model_name,
+                created_time=created_time,
+                output=[],
+                status="queued",
+                usage=None,
+            )
+            async with self.response_store_lock:
+                self.response_store[response.id] = response
+
+            # Run the request in the background.
+            task = asyncio.create_task(
+                self._run_background_request(
                     request,
                     sampling_params,
                     result_generator,
@@ -369,21 +336,41 @@ class OpenAIServingResponses(OpenAIServing):
                     model_name,
                     tokenizer,
                     request_metadata,
-                )
+                    created_time,
+                ),
+                name=f"create_{response.id}",
+            )
 
-            try:
-                return await self.responses_full_generator(
-                    request,
-                    sampling_params,
-                    result_generator,
-                    context,
-                    model_name,
-                    tokenizer,
-                    request_metadata,
-                )
-            except Exception as e:
-                return self.create_error_response(str(e))
-        return self.create_error_response("Should not reach here")
+            # For cleanup.
+            response_id = response.id
+            self.background_tasks[response_id] = task
+            task.add_done_callback(
+                lambda _: self.background_tasks.pop(response_id, None))
+            return response
+
+        if request.stream:
+            return self.responses_stream_generator(
+                request,
+                sampling_params,
+                result_generator,
+                context,
+                model_name,
+                tokenizer,
+                request_metadata,
+            )
+
+        try:
+            return await self.responses_full_generator(
+                request,
+                sampling_params,
+                result_generator,
+                context,
+                model_name,
+                tokenizer,
+                request_metadata,
+            )
+        except Exception as e:
+            return self.create_error_response(str(e))
 
     async def _make_request(
         self,
@@ -439,14 +426,16 @@ class OpenAIServingResponses(OpenAIServing):
         if created_time is None:
             created_time = int(time.time())
 
-        try:
-            async for _ in result_generator:
-                pass
-        except asyncio.CancelledError:
-            return self.create_error_response("Client disconnected")
-        except ValueError as e:
-            # TODO: Use a vllm-specific Validation Error
-            return self.create_error_response(str(e))
+        async with AsyncExitStack() as exit_stack:
+            try:
+                await context.init_tool_sessions(self.tool_server, exit_stack)
+                async for _ in result_generator:
+                    pass
+            except asyncio.CancelledError:
+                return self.create_error_response("Client disconnected")
+            except ValueError as e:
+                # TODO: Use a vllm-specific Validation Error
+                return self.create_error_response(str(e))
 
         if self.use_harmony:
             assert isinstance(context, HarmonyContext)
@@ -838,7 +827,7 @@ class OpenAIServingResponses(OpenAIServing):
             status_code=HTTPStatus.BAD_REQUEST,
         )
 
-    async def responses_stream_generator(
+    async def _process_streaming_events(
         self,
         request: ResponsesRequest,
         sampling_params: SamplingParams,
@@ -847,18 +836,8 @@ class OpenAIServingResponses(OpenAIServing):
         model_name: str,
         tokenizer: AnyTokenizer,
         request_metadata: RequestResponseMetadata,
-        created_time: Optional[int] = None,
+        created_time: int,
     ) -> AsyncGenerator[str, None]:
-        # TODO:
-        # 1. Handle disconnect
-
-        if not isinstance(context, StreamingHarmonyContext):
-            raise NotImplementedError(
-                "Streaming is not supported for responses API without Harmony."
-            )
-
-        created_time = created_time or int(time.time())
-
         sequence_number = 0
 
         def _send_event(event: BaseModel):
@@ -1270,3 +1249,31 @@ class OpenAIServingResponses(OpenAIServing):
                     sequence_number=-1,
                     response=final_response.model_dump(),
                 ))
+
+    async def responses_stream_generator(
+        self,
+        request: ResponsesRequest,
+        sampling_params: SamplingParams,
+        result_generator: AsyncIterator[Optional[ConversationContext]],
+        context: ConversationContext,
+        model_name: str,
+        tokenizer: AnyTokenizer,
+        request_metadata: RequestResponseMetadata,
+        created_time: Optional[int] = None,
+    ) -> AsyncGenerator[str, None]:
+        # TODO:
+        # 1. Handle disconnect
+
+        if not isinstance(context, StreamingHarmonyContext):
+            raise NotImplementedError(
+                "Streaming is not supported for responses API without Harmony."
+            )
+
+        created_time = created_time or int(time.time())
+
+        async with AsyncExitStack() as exit_stack:
+            await context.init_tool_sessions(self.tool_server, exit_stack)
+            async for event_data in self._process_streaming_events(
+                    request, sampling_params, result_generator, context,
+                    model_name, tokenizer, request_metadata, created_time):
+                yield event_data
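
The lifecycle change at the heart of this patch: create_responses no longer opens tool sessions eagerly. The context only records the names of the available built-in tools, and each consumer of the generator (responses_full_generator, the streaming path, and the background task) opens MCP sessions lazily through init_tool_sessions inside its own AsyncExitStack, so the sessions are owned by the coroutine that actually drives the request. That is what makes background mode workable: the background task outlives the original HTTP handler, so it has to own its own session lifetime. Below is a minimal, self-contained sketch of that pattern using only the standard library; DummyToolServer, Context, and run_request are illustrative stand-ins, not vLLM APIs.

import asyncio
from contextlib import AsyncExitStack, asynccontextmanager


class DummyToolServer:
    """Illustrative stand-in for a tool server that hands out per-tool sessions."""

    @asynccontextmanager
    async def new_session(self, tool_name: str):
        # A real server would create an MCP ClientSession here; we print for visibility.
        print(f"open {tool_name} session")
        try:
            yield f"session:{tool_name}"
        finally:
            print(f"close {tool_name} session")


class Context:
    """Mirrors the available_tools / _tool_sessions split introduced by the patch."""

    def __init__(self, available_tools: list[str]):
        self.available_tools = available_tools
        self._tool_sessions: dict[str, object] = {}

    async def init_tool_sessions(self, tool_server, exit_stack: AsyncExitStack) -> None:
        # Sessions are opened only when a consumer asks for them, and they are
        # registered on the caller's exit stack, so the caller's lifetime bounds them.
        if tool_server:
            for name in self.available_tools:
                if name not in self._tool_sessions:
                    self._tool_sessions[name] = await exit_stack.enter_async_context(
                        tool_server.new_session(name))


async def run_request(ctx: Context, tool_server: DummyToolServer) -> None:
    # Each consumer (full, streaming, or background path) owns its own exit stack,
    # so sessions close when this coroutine finishes, even if it was spawned with
    # asyncio.create_task() and outlives the original HTTP handler.
    async with AsyncExitStack() as exit_stack:
        await ctx.init_tool_sessions(tool_server, exit_stack)
        print("tool sessions ready:", ctx._tool_sessions)


asyncio.run(run_request(Context(["browser", "python"]), DummyToolServer()))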
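On the client side, the effect of the patch is that a Responses API request relying on the MCP-backed built-in tools may now also be submitted with background=true and polled later. A hedged usage sketch against a locally served vLLM endpoint follows; the base_url, api_key, model name, tool selection, and polling interval are placeholder assumptions, not values taken from the patch.

import time

from openai import OpenAI

# Placeholder endpoint for a locally served vLLM instance with a tool server configured.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

resp = client.responses.create(
    model="openai/gpt-oss-20b",  # assumed: any Harmony-capable model served by vLLM
    input="Find the latest vLLM release notes and summarize them.",
    tools=[{"type": "web_search_preview"}],  # backed by the MCP tool server when configured
    background=True,  # previously rejected together with these tools; allowed after this patch
)

# Background responses start out "queued"; poll the stored response until it settles.
while resp.status in ("queued", "in_progress"):
    time.sleep(1)
    resp = client.responses.retrieve(resp.id)

print(resp.status)
print(resp.output_text)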