[Feature][Responses API] Support MCP tools with streaming mode + background mode (#23927)

Signed-off-by: wuhang <wuhang6@huawei.com>
wuhang 2025-09-04 12:05:10 +08:00 committed by GitHub
parent b5ee1e3261
commit a38f8bd54c
3 changed files with 138 additions and 26 deletions

View File

@@ -275,7 +275,8 @@ async def test_stateful_multi_turn(client: OpenAI, model_name: str):
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_streaming(client: OpenAI, model_name: str):
+@pytest.mark.parametrize("background", [True, False])
+async def test_streaming(client: OpenAI, model_name: str, background: bool):
     # TODO: Add back when web search and code interpreter are available in CI
     prompts = [
         "tell me a story about a cat in 20 words",
@@ -300,11 +301,16 @@ async def test_streaming(client: OpenAI, model_name: str):
             # },
         ],
         stream=True,
+        background=background,
    )
    events = []
    current_event_mode = None
+    resp_id = None
    async for event in response:
+        if event.type == "response.created":
+            resp_id = event.response.id
        if current_event_mode != event.type:
            current_event_mode = event.type
            print(f"\n[{event.type}] ", end="", flush=True)
@@ -322,6 +328,17 @@ async def test_streaming(client: OpenAI, model_name: str):
    assert len(events) > 0
 
+    if background:
+        starting_after = 5
+        async with await client.responses.retrieve(
+                response_id=resp_id,
+                stream=True,
+                starting_after=starting_after) as stream:
+            counter = starting_after
+            async for event in stream:
+                counter += 1
+                assert event == events[counter]
+
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])

View File

@@ -616,14 +616,23 @@ async def create_responses(request: ResponsesRequest, raw_request: Request):
 
 @router.get("/v1/responses/{response_id}")
-async def retrieve_responses(response_id: str, raw_request: Request):
+async def retrieve_responses(
+    response_id: str,
+    raw_request: Request,
+    starting_after: Optional[int] = None,
+    stream: Optional[bool] = False,
+):
    handler = responses(raw_request)
    if handler is None:
        return base(raw_request).create_error_response(
            message="The model does not support Responses API")
 
    try:
-        response = await handler.retrieve_responses(response_id)
+        response = await handler.retrieve_responses(
+            response_id,
+            starting_after=starting_after,
+            stream=stream,
+        )
    except Exception as e:
        raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
                            detail=str(e)) from e
@@ -631,6 +640,9 @@ async def retrieve_responses(response_id: str, raw_request: Request):
    if isinstance(response, ErrorResponse):
        return JSONResponse(content=response.model_dump(),
                            status_code=response.error.code)
+    elif stream:
+        return StreamingResponse(content=response,
+                                 media_type="text/event-stream")
 
    return JSONResponse(content=response.model_dump())
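Note: at the HTTP level the new arguments are ordinary query parameters on the existing GET route; stream=true switches the response to SSE. A minimal raw-HTTP sketch with httpx (server address and response id are assumptions):

import httpx

# Re-attach to a background response, replaying everything after event 5.
with httpx.stream(
        "GET",
        "http://localhost:8000/v1/responses/resp_abc123",  # placeholder id
        params={"stream": True, "starting_after": 5},
        timeout=None,
) as resp:
    for line in resp.iter_lines():
        if line:
            print(line)  # raw SSE lines, e.g. "data: {...}"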

View File

@ -4,6 +4,7 @@
import asyncio
import json
import time
from collections import deque
from collections.abc import AsyncGenerator, AsyncIterator, Sequence
from contextlib import AsyncExitStack
from copy import copy
@@ -55,7 +56,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
 # yapf: enable
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
-from vllm.entrypoints.tool_server import MCPToolServer, ToolServer
+from vllm.entrypoints.tool_server import ToolServer
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob as SampleLogprob
@@ -168,6 +169,11 @@ class OpenAIServingResponses(OpenAIServing):
        # never remove messages from the store.
        self.msg_store: dict[str, list[ChatCompletionMessageParam]] = {}
 
+        # HACK(wuhang): This is a hack. We should use a better store.
+        # FIXME: If enable_store=True, this may cause a memory leak since we
+        # never remove events from the store.
+        self.event_store: dict[str, tuple[deque[str], asyncio.Event]] = {}
+
        self.background_tasks: dict[str, asyncio.Task] = {}
 
        self.tool_server = tool_server
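Note: each entry pairs an append-only deque with an asyncio.Event, which amounts to a small single-producer / multi-consumer replay buffer: the producer only appends and sets the event, and every consumer keeps its own read index, so a late or resumed reader can replay from any position. A standalone sketch of the pattern (illustration only, not vLLM code):

import asyncio
from collections import deque

async def producer(buf: deque, signal: asyncio.Event):
    for i in range(3):
        buf.append(f"event-{i}")
        signal.set()  # wake any waiting consumers
        await asyncio.sleep(0)  # yield control so consumers can run
    buf.append("__STREAM_END__")  # sentinel marks end of stream
    signal.set()

async def consumer(buf: deque, signal: asyncio.Event, start: int = 0):
    i = start
    while True:
        signal.clear()
        while i < len(buf):  # drain everything appended so far
            if buf[i] == "__STREAM_END__":
                return
            print("got", buf[i])
            i += 1
        await signal.wait()  # block until the producer appends more

async def main():
    buf, signal = deque(), asyncio.Event()
    await asyncio.gather(producer(buf, signal), consumer(buf, signal))

asyncio.run(main())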
@@ -249,15 +255,6 @@ class OpenAIServingResponses(OpenAIServing):
        if raw_request:
            raw_request.state.request_metadata = request_metadata
 
-        if self.tool_server is not None and isinstance(
-                self.tool_server,
-                MCPToolServer) and request.stream and request.tools and any(
-                    tool.type in ["web_search_preview", "code_interpreter"]
-                    for tool in request.tools):
-            return self.create_error_response(
-                "MCP tool server is not supported in background mode and "
-                "streaming mode")
-
        # Schedule the request and get the result generator.
        generators: list[AsyncGenerator[ConversationContext, None]] = []
@@ -329,25 +326,44 @@
                self.response_store[response.id] = response
 
            # Run the request in the background.
-            task = asyncio.create_task(
-                self._run_background_request(
-                    request,
-                    sampling_params,
-                    result_generator,
-                    context,
-                    model_name,
-                    tokenizer,
-                    request_metadata,
-                    created_time,
-                ),
-                name=f"create_{response.id}",
-            )
+            if request.stream:
+                task = asyncio.create_task(
+                    self._run_background_request_stream(
+                        request,
+                        sampling_params,
+                        result_generator,
+                        context,
+                        model_name,
+                        tokenizer,
+                        request_metadata,
+                        created_time,
+                    ),
+                    name=f"create_{request.request_id}",
+                )
+            else:
+                task = asyncio.create_task(
+                    self._run_background_request(
+                        request,
+                        sampling_params,
+                        result_generator,
+                        context,
+                        model_name,
+                        tokenizer,
+                        request_metadata,
+                        created_time,
+                    ),
+                    name=f"create_{response.id}",
+                )
 
            # For cleanup.
            response_id = response.id
            self.background_tasks[response_id] = task
            task.add_done_callback(
                lambda _: self.background_tasks.pop(response_id, None))
+
+            if request.stream:
+                return self.responses_background_stream_generator(
+                    request.request_id)
            return response
 
        if request.stream:
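Note: the cleanup wiring above is the standard asyncio idiom: keep a strong reference to the task so it cannot be garbage-collected mid-flight, and drop that reference from a done-callback. The idiom in isolation (ids are hypothetical):

import asyncio

background_tasks: dict[str, asyncio.Task] = {}

async def work(task_id: str) -> str:
    await asyncio.sleep(0.1)
    return task_id

async def main():
    task_id = "resp_demo"  # placeholder id
    task = asyncio.create_task(work(task_id), name=f"create_{task_id}")
    background_tasks[task_id] = task  # strong reference keeps the task alive
    task.add_done_callback(
        lambda _: background_tasks.pop(task_id, None))  # drop on completion
    await task
    await asyncio.sleep(0)  # let the done-callback run on the event loop
    assert task_id not in background_tasks

asyncio.run(main())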
@@ -736,6 +752,40 @@
                prev_outputs.append(response_msg)
        return messages
 
+    async def _run_background_request_stream(
+        self,
+        request: ResponsesRequest,
+        *args,
+        **kwargs,
+    ):
+        event_deque: deque[str] = deque()
+        new_event_signal = asyncio.Event()
+        self.event_store[request.request_id] = (event_deque, new_event_signal)
+
+        response = None
+        try:
+            generator = self.responses_stream_generator(
+                request, *args, **kwargs)
+            async for event in generator:
+                event_deque.append(event)
+                new_event_signal.set()  # Signal new event available
+        except Exception as e:
+            logger.exception("Background request failed for %s",
+                             request.request_id)
+            response = self.create_error_response(str(e))
+        finally:
+            # Mark as finished with a special marker
+            event_deque.append("__STREAM_END__")
+            new_event_signal.set()
+
+        if response is not None and isinstance(response, ErrorResponse):
+            # If the request has failed, update the status to "failed".
+            response_id = request.request_id
+            async with self.response_store_lock:
+                stored_response = self.response_store.get(response_id)
+                assert stored_response is not None
+                if stored_response.status not in ("completed", "cancelled"):
+                    stored_response.status = "failed"
+
    async def _run_background_request(
        self,
        request: ResponsesRequest,
@@ -759,9 +809,36 @@
                if stored_response.status not in ("completed", "cancelled"):
                    stored_response.status = "failed"
 
+    async def responses_background_stream_generator(
+        self,
+        response_id: str,
+        starting_after: Optional[int] = None,
+    ):
+        if response_id not in self.event_store:
+            raise ValueError(f"Unknown response_id: {response_id}")
+
+        event_deque, new_event_signal = self.event_store[response_id]
+        start_index = 0 if starting_after is None else starting_after + 1
+        current_index = start_index
+
+        while True:
+            new_event_signal.clear()
+            # Yield existing events from start_index
+            while current_index < len(event_deque):
+                event = event_deque[current_index]
+                if event == "__STREAM_END__":
+                    return
+                yield event
+                current_index += 1
+            await new_event_signal.wait()
+
    async def retrieve_responses(
        self,
        response_id: str,
+        starting_after: Optional[int],
+        stream: Optional[bool],
    ) -> Union[ErrorResponse, ResponsesResponse]:
        if not response_id.startswith("resp_"):
            return self._make_invalid_id_error(response_id)
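Note: starting_after is a 0-based index into the events emitted so far, and replay resumes at the next one, mirroring the counter = starting_after; counter += 1 bookkeeping in the test above. The arithmetic in isolation (placeholder events):

events = [f"ev{i}" for i in range(10)]  # placeholder event list

starting_after = 5
start_index = 0 if starting_after is None else starting_after + 1
assert start_index == 6
assert events[start_index:] == ["ev6", "ev7", "ev8", "ev9"]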
@@ -771,6 +848,12 @@
        if response is None:
            return self._make_not_found_error(response_id)
 
+        if stream:
+            return self.responses_background_stream_generator(
+                response_id,
+                starting_after,
+            )
+
        return response
 
    async def cancel_responses(