[Feature][Responses API]Support MCP tools with streaming mode + background mode (#23927)

Signed-off-by: wuhang <wuhang6@huawei.com>
wuhang 2025-09-04 12:05:10 +08:00 committed by GitHub
parent b5ee1e3261
commit a38f8bd54c
3 changed files with 138 additions and 26 deletions
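
With these changes a client can create a response with both background=True and stream=True, stream events as they are produced, and later re-attach to the same stream via retrieve with stream=True and starting_after. A minimal client-side sketch, assuming a local vLLM server at http://localhost:8000/v1 and a placeholder model name (base URL, API key, and model are illustrative, not part of this change):

import asyncio

from openai import AsyncOpenAI

async def main():
    # Placeholder base URL, API key, and model name for a local vLLM deployment.
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    # Start a background response and stream events as they are produced.
    stream = await client.responses.create(
        model="your-model",
        input="tell me a story about a cat in 20 words",
        stream=True,
        background=True,
    )
    resp_id = None
    events = []
    async for event in stream:
        if event.type == "response.created":
            resp_id = event.response.id
        events.append(event)

    # Later (possibly from another connection), re-attach to the same stream,
    # skipping the first six events (indices 0..5) that were already seen.
    async with await client.responses.retrieve(
            response_id=resp_id,
            stream=True,
            starting_after=5) as resumed:
        async for event in resumed:
            print(event.type)

asyncio.run(main())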

View File

@@ -275,7 +275,8 @@ async def test_stateful_multi_turn(client: OpenAI, model_name: str):
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_streaming(client: OpenAI, model_name: str):
+@pytest.mark.parametrize("background", [True, False])
+async def test_streaming(client: OpenAI, model_name: str, background: bool):
     # TODO: Add back when web search and code interpreter are available in CI
     prompts = [
         "tell me a story about a cat in 20 words",
@@ -300,11 +301,16 @@ async def test_streaming(client: OpenAI, model_name: str):
            # },
        ],
        stream=True,
+       background=background,
    )
    events = []
    current_event_mode = None
+   resp_id = None
    async for event in response:
+       if event.type == "response.created":
+           resp_id = event.response.id
        if current_event_mode != event.type:
            current_event_mode = event.type
            print(f"\n[{event.type}] ", end="", flush=True)
@@ -322,6 +328,17 @@ async def test_streaming(client: OpenAI, model_name: str):
    assert len(events) > 0

+   if background:
+       starting_after = 5
+       async with await client.responses.retrieve(
+               response_id=resp_id,
+               stream=True,
+               starting_after=starting_after) as stream:
+           counter = starting_after
+           async for event in stream:
+               counter += 1
+               assert event == events[counter]
+

 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
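
The resume check above relies on the indexing convention for starting_after: it names the index of the last event already consumed, so the replay begins at starting_after + 1 (mirroring the start_index computation in the server-side generator further down). A tiny illustration, assuming events is the list collected from the original stream in the test:

# If the client already consumed events[0] through events[5],
# it resumes with starting_after=5 and expects events[6:] to be replayed.
starting_after = 5
expected_replay = events[starting_after + 1:]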

View File

@@ -616,14 +616,23 @@ async def create_responses(request: ResponsesRequest, raw_request: Request):
 @router.get("/v1/responses/{response_id}")
-async def retrieve_responses(response_id: str, raw_request: Request):
+async def retrieve_responses(
+    response_id: str,
+    raw_request: Request,
+    starting_after: Optional[int] = None,
+    stream: Optional[bool] = False,
+):
     handler = responses(raw_request)
     if handler is None:
         return base(raw_request).create_error_response(
             message="The model does not support Responses API")

     try:
-        response = await handler.retrieve_responses(response_id)
+        response = await handler.retrieve_responses(
+            response_id,
+            starting_after=starting_after,
+            stream=stream,
+        )
     except Exception as e:
         raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
                             detail=str(e)) from e
@@ -631,6 +640,9 @@ async def retrieve_responses(response_id: str, raw_request: Request):
     if isinstance(response, ErrorResponse):
         return JSONResponse(content=response.model_dump(),
                             status_code=response.error.code)
+    elif stream:
+        return StreamingResponse(content=response,
+                                 media_type="text/event-stream")
     return JSONResponse(content=response.model_dump())
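
At the HTTP level the retrieve endpoint now takes stream and starting_after as query parameters and answers with a text/event-stream body when stream is set. A rough sketch of hitting it directly, assuming a server on localhost:8000 and a placeholder response id (both illustrative):

import requests

# Placeholder host and response id; starting_after resumes after event index 5.
url = "http://localhost:8000/v1/responses/resp_abc123"
with requests.get(url,
                  params={"stream": "true", "starting_after": 5},
                  stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if line:  # SSE frames are separated by blank lines
            print(line)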

View File

@@ -4,6 +4,7 @@
 import asyncio
 import json
 import time
+from collections import deque
 from collections.abc import AsyncGenerator, AsyncIterator, Sequence
 from contextlib import AsyncExitStack
 from copy import copy
@@ -55,7 +56,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
 # yapf: enable
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
-from vllm.entrypoints.tool_server import MCPToolServer, ToolServer
+from vllm.entrypoints.tool_server import ToolServer
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob as SampleLogprob
@@ -168,6 +169,11 @@ class OpenAIServingResponses(OpenAIServing):
         # never remove messages from the store.
         self.msg_store: dict[str, list[ChatCompletionMessageParam]] = {}

+        # HACK(wuhang): This is a hack. We should use a better store.
+        # FIXME: If enable_store=True, this may cause a memory leak since we
+        # never remove events from the store.
+        self.event_store: dict[str, tuple[deque[str], asyncio.Event]] = {}
+
         self.background_tasks: dict[str, asyncio.Task] = {}

         self.tool_server = tool_server
@@ -249,15 +255,6 @@ class OpenAIServingResponses(OpenAIServing):
         if raw_request:
             raw_request.state.request_metadata = request_metadata

-        if self.tool_server is not None and isinstance(
-                self.tool_server,
-                MCPToolServer) and request.stream and request.tools and any(
-                    tool.type in ["web_search_preview", "code_interpreter"]
-                    for tool in request.tools):
-            return self.create_error_response(
-                "MCP tool server is not supported in background mode and "
-                "streaming mode")
-
         # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[ConversationContext, None]] = []
@@ -329,25 +326,44 @@ class OpenAIServingResponses(OpenAIServing):
                 self.response_store[response.id] = response

             # Run the request in the background.
-            task = asyncio.create_task(
-                self._run_background_request(
-                    request,
-                    sampling_params,
-                    result_generator,
-                    context,
-                    model_name,
-                    tokenizer,
-                    request_metadata,
-                    created_time,
-                ),
-                name=f"create_{response.id}",
-            )
+            if request.stream:
+                task = asyncio.create_task(
+                    self._run_background_request_stream(
+                        request,
+                        sampling_params,
+                        result_generator,
+                        context,
+                        model_name,
+                        tokenizer,
+                        request_metadata,
+                        created_time,
+                    ),
+                    name=f"create_{request.request_id}",
+                )
+            else:
+                task = asyncio.create_task(
+                    self._run_background_request(
+                        request,
+                        sampling_params,
+                        result_generator,
+                        context,
+                        model_name,
+                        tokenizer,
+                        request_metadata,
+                        created_time,
+                    ),
+                    name=f"create_{response.id}",
+                )

             # For cleanup.
             response_id = response.id
             self.background_tasks[response_id] = task
             task.add_done_callback(
                 lambda _: self.background_tasks.pop(response_id, None))
+
+            if request.stream:
+                return self.responses_background_stream_generator(
+                    request.request_id)
             return response

         if request.stream:
@@ -736,6 +752,40 @@ class OpenAIServingResponses(OpenAIServing):
             prev_outputs.append(response_msg)
         return messages

+    async def _run_background_request_stream(
+        self,
+        request: ResponsesRequest,
+        *args,
+        **kwargs,
+    ):
+        event_deque: deque[str] = deque()
+        new_event_signal = asyncio.Event()
+        self.event_store[request.request_id] = (event_deque, new_event_signal)
+        response = None
+        try:
+            generator = self.responses_stream_generator(
+                request, *args, **kwargs)
+            async for event in generator:
+                event_deque.append(event)
+                new_event_signal.set()  # Signal new event available
+        except Exception as e:
+            logger.exception("Background request failed for %s",
+                             request.request_id)
+            response = self.create_error_response(str(e))
+        finally:
+            # Mark as finished with a special marker
+            event_deque.append("__STREAM_END__")
+            new_event_signal.set()
+
+        if response is not None and isinstance(response, ErrorResponse):
+            # If the request has failed, update the status to "failed".
+            response_id = request.request_id
+            async with self.response_store_lock:
+                stored_response = self.response_store.get(response_id)
+                assert stored_response is not None
+                if stored_response.status not in ("completed", "cancelled"):
+                    stored_response.status = "failed"
+
     async def _run_background_request(
         self,
         request: ResponsesRequest,
@@ -759,9 +809,36 @@ class OpenAIServingResponses(OpenAIServing):
                 if stored_response.status not in ("completed", "cancelled"):
                     stored_response.status = "failed"

+    async def responses_background_stream_generator(
+        self,
+        response_id: str,
+        starting_after: Optional[int] = None,
+    ):
+        if response_id not in self.event_store:
+            raise ValueError(f"Unknown response_id: {response_id}")
+
+        event_deque, new_event_signal = self.event_store[response_id]
+        start_index = 0 if starting_after is None else starting_after + 1
+        current_index = start_index
+
+        while True:
+            new_event_signal.clear()
+            # Yield existing events from start_index
+            while current_index < len(event_deque):
+                event = event_deque[current_index]
+                if event == "__STREAM_END__":
+                    return
+                yield event
+                current_index += 1
+            await new_event_signal.wait()
+
     async def retrieve_responses(
         self,
         response_id: str,
+        starting_after: Optional[int],
+        stream: Optional[bool],
     ) -> Union[ErrorResponse, ResponsesResponse]:
         if not response_id.startswith("resp_"):
             return self._make_invalid_id_error(response_id)
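
The two methods above form a small single-process replay buffer: the background task appends each serialized event to a per-request deque and sets an asyncio.Event, while any number of readers walk the deque by index and wait on the event once they catch up, stopping at the __STREAM_END__ sentinel. A stripped-down sketch of that pattern in isolation (the producer/reader names are illustrative, not vLLM symbols):

import asyncio
from collections import deque

async def producer(events: deque[str], signal: asyncio.Event) -> None:
    # Append chunks as they are generated, then mark the end of the stream.
    for i in range(3):
        events.append(f"data: event-{i}\n\n")
        signal.set()
        await asyncio.sleep(0.01)
    events.append("__STREAM_END__")
    signal.set()

async def reader(events: deque[str], signal: asyncio.Event, start: int = 0) -> None:
    index = start
    while True:
        signal.clear()
        # Drain everything appended so far, then sleep until more arrives.
        while index < len(events):
            chunk = events[index]
            if chunk == "__STREAM_END__":
                return
            print(chunk.strip())
            index += 1
        await signal.wait()

async def main() -> None:
    events: deque[str] = deque()
    signal = asyncio.Event()
    await asyncio.gather(producer(events, signal), reader(events, signal))

asyncio.run(main())

Because everything runs on one event loop and events are only ever appended, readers never race with the producer; a late reader simply replays from its chosen start index.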
@@ -771,6 +848,12 @@ class OpenAIServingResponses(OpenAIServing):
         if response is None:
             return self._make_not_found_error(response_id)

+        if stream:
+            return self.responses_background_stream_generator(
+                response_id,
+                starting_after,
+            )
+
         return response

     async def cancel_responses(