mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Feature][Responses API] Support MCP tools with streaming mode + background mode (#23927)
Signed-off-by: wuhang <wuhang6@huawei.com>
commit a38f8bd54c
parent b5ee1e3261
@@ -275,7 +275,8 @@ async def test_stateful_multi_turn(client: OpenAI, model_name: str):
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_streaming(client: OpenAI, model_name: str):
+@pytest.mark.parametrize("background", [True, False])
+async def test_streaming(client: OpenAI, model_name: str, background: bool):
     # TODO: Add back when web search and code interpreter are available in CI
     prompts = [
         "tell me a story about a cat in 20 words",
@@ -300,11 +301,16 @@ async def test_streaming(client: OpenAI, model_name: str):
         # },
         ],
         stream=True,
+        background=background,
     )
 
     events = []
     current_event_mode = None
+    resp_id = None
     async for event in response:
+        if event.type == "response.created":
+            resp_id = event.response.id
+
         if current_event_mode != event.type:
             current_event_mode = event.type
             print(f"\n[{event.type}] ", end="", flush=True)
@@ -322,6 +328,17 @@ async def test_streaming(client: OpenAI, model_name: str):
 
     assert len(events) > 0
 
+    if background:
+        starting_after = 5
+        async with await client.responses.retrieve(
+                response_id=resp_id,
+                stream=True,
+                starting_after=starting_after) as stream:
+            counter = starting_after
+            async for event in stream:
+                counter += 1
+                assert event == events[counter]
+
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
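The parametrized test above exercises the new resume path: when background=True it records the response id from the response.created event, drops the original stream, and then reattaches with starting_after=5, checking that the replayed events match the ones it already collected. Outside pytest the same round trip looks roughly like the sketch below; this is a minimal sketch, assuming an AsyncOpenAI client pointed at a vLLM server on localhost:8000, and "some-model" is a placeholder model name.

# Minimal sketch (not part of this diff): start a background streaming
# response, detach, then resume it from a given event index.
# Assumes a vLLM server at http://localhost:8000/v1; "some-model" is a
# placeholder model name.
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    # Start a background + streaming response and remember its id.
    response = await client.responses.create(
        model="some-model",
        input="tell me a story about a cat in 20 words",
        stream=True,
        background=True,
    )
    resp_id = None
    async for event in response:
        if event.type == "response.created":
            resp_id = event.response.id
            break  # drop the connection; generation continues server-side

    # Later: reattach and replay every event after index 5.
    async with await client.responses.retrieve(
            response_id=resp_id, stream=True, starting_after=5) as stream:
        async for event in stream:
            print(event.type)


asyncio.run(main())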
@@ -616,14 +616,23 @@ async def create_responses(request: ResponsesRequest, raw_request: Request):
 
 
 @router.get("/v1/responses/{response_id}")
-async def retrieve_responses(response_id: str, raw_request: Request):
+async def retrieve_responses(
+    response_id: str,
+    raw_request: Request,
+    starting_after: Optional[int] = None,
+    stream: Optional[bool] = False,
+):
     handler = responses(raw_request)
     if handler is None:
         return base(raw_request).create_error_response(
             message="The model does not support Responses API")
 
     try:
-        response = await handler.retrieve_responses(response_id)
+        response = await handler.retrieve_responses(
+            response_id,
+            starting_after=starting_after,
+            stream=stream,
+        )
     except Exception as e:
         raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
                             detail=str(e)) from e
@@ -631,6 +640,9 @@ async def retrieve_responses(response_id: str, raw_request: Request):
     if isinstance(response, ErrorResponse):
         return JSONResponse(content=response.model_dump(),
                             status_code=response.error.code)
+    elif stream:
+        return StreamingResponse(content=response,
+                                 media_type="text/event-stream")
     return JSONResponse(content=response.model_dump())
 
 
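With these changes, GET /v1/responses/{response_id} accepts two optional query parameters: starting_after (index of the last event the client has already seen) and stream; when stream is set, the handler returns the event generator wrapped in a StreamingResponse with the text/event-stream media type instead of a JSON body. Consuming the endpoint over raw HTTP looks roughly like the sketch below; this is a sketch only, assuming the requests package and a server on localhost:8000, with "resp_xxx" standing in for a real response id.

# Sketch only: read the resumed event stream over plain HTTP.
# Assumes the `requests` package and a vLLM server on localhost:8000;
# "resp_xxx" is a placeholder response id.
import requests

resp = requests.get(
    "http://localhost:8000/v1/responses/resp_xxx",
    params={"stream": "true", "starting_after": 5},
    stream=True,
)
for line in resp.iter_lines():
    if line:
        print(line.decode("utf-8"))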
@@ -4,6 +4,7 @@
 import asyncio
 import json
 import time
+from collections import deque
 from collections.abc import AsyncGenerator, AsyncIterator, Sequence
 from contextlib import AsyncExitStack
 from copy import copy
@@ -55,7 +56,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
 # yapf: enable
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
-from vllm.entrypoints.tool_server import MCPToolServer, ToolServer
+from vllm.entrypoints.tool_server import ToolServer
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob as SampleLogprob
@@ -168,6 +169,11 @@ class OpenAIServingResponses(OpenAIServing):
         # never remove messages from the store.
         self.msg_store: dict[str, list[ChatCompletionMessageParam]] = {}
 
+        # HACK(wuhang): This is a hack. We should use a better store.
+        # FIXME: If enable_store=True, this may cause a memory leak since we
+        # never remove events from the store.
+        self.event_store: dict[str, tuple[deque[str], asyncio.Event]] = {}
+
         self.background_tasks: dict[str, asyncio.Task] = {}
 
         self.tool_server = tool_server
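The new event_store keeps, per response id, a deque of every event emitted so far plus an asyncio.Event used as a "new data available" signal, so a consumer that attaches late can replay the backlog by index and then follow the live tail. The handshake is the usual append-then-set on the producer side and clear-then-drain-then-wait on the consumer side; the following self-contained sketch of that pattern uses illustrative names and is not vLLM code.

# Self-contained sketch of the deque + asyncio.Event handshake that the
# event_store relies on (illustrative names, not vLLM code).
import asyncio
from collections import deque

END = "__STREAM_END__"  # sentinel marking the end of the stream


async def producer(buf: deque, signal: asyncio.Event) -> None:
    for i in range(5):
        await asyncio.sleep(0.01)  # stand-in for generation work
        buf.append(f"event-{i}")
        signal.set()               # wake any waiting consumer
    buf.append(END)
    signal.set()


async def consumer(buf: deque, signal: asyncio.Event) -> None:
    index = 0
    while True:
        signal.clear()
        while index < len(buf):    # drain everything not yet seen
            item = buf[index]
            if item == END:
                return
            print(item)
            index += 1
        await signal.wait()        # block until the producer adds more


async def main() -> None:
    buf: deque = deque()
    signal = asyncio.Event()
    await asyncio.gather(producer(buf, signal), consumer(buf, signal))


asyncio.run(main())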
@@ -249,15 +255,6 @@ class OpenAIServingResponses(OpenAIServing):
         if raw_request:
             raw_request.state.request_metadata = request_metadata
 
-        if self.tool_server is not None and isinstance(
-                self.tool_server,
-                MCPToolServer) and request.stream and request.tools and any(
-                    tool.type in ["web_search_preview", "code_interpreter"]
-                    for tool in request.tools):
-            return self.create_error_response(
-                "MCP tool server is not supported in background mode and "
-                "streaming mode")
-
         # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[ConversationContext, None]] = []
 
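Before this change, a request that combined an MCP-backed tool (web_search_preview or code_interpreter) with streaming was rejected up front by the guard removed above. With the streaming and background plumbing added in this commit the guard is gone, so a request along the lines of the sketch below is accepted; this is a sketch only, assuming a vLLM server started with an MCP tool server, the exact tool payload the server accepts may differ, and "some-model" is a placeholder model name.

# Sketch only (not from the diff): an MCP-backed tool combined with streaming,
# which the removed guard used to reject. Assumes a vLLM server configured
# with an MCP tool server; "some-model" is a placeholder model name.
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    stream = await client.responses.create(
        model="some-model",
        input="search the web for the latest vLLM release notes",
        tools=[{"type": "web_search_preview"}],
        stream=True,
    )
    async for event in stream:
        print(event.type)


asyncio.run(main())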
@@ -329,25 +326,44 @@ class OpenAIServingResponses(OpenAIServing):
                 self.response_store[response.id] = response
 
             # Run the request in the background.
-            task = asyncio.create_task(
-                self._run_background_request(
-                    request,
-                    sampling_params,
-                    result_generator,
-                    context,
-                    model_name,
-                    tokenizer,
-                    request_metadata,
-                    created_time,
-                ),
-                name=f"create_{response.id}",
-            )
+            if request.stream:
+                task = asyncio.create_task(
+                    self._run_background_request_stream(
+                        request,
+                        sampling_params,
+                        result_generator,
+                        context,
+                        model_name,
+                        tokenizer,
+                        request_metadata,
+                        created_time,
+                    ),
+                    name=f"create_{request.request_id}",
+                )
+            else:
+                task = asyncio.create_task(
+                    self._run_background_request(
+                        request,
+                        sampling_params,
+                        result_generator,
+                        context,
+                        model_name,
+                        tokenizer,
+                        request_metadata,
+                        created_time,
+                    ),
+                    name=f"create_{response.id}",
+                )
 
             # For cleanup.
             response_id = response.id
             self.background_tasks[response_id] = task
             task.add_done_callback(
                 lambda _: self.background_tasks.pop(response_id, None))
+
+            if request.stream:
+                return self.responses_background_stream_generator(
+                    request.request_id)
             return response
 
         if request.stream:
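For background requests the handler now chooses the worker coroutine based on request.stream: streaming requests run through _run_background_request_stream, which feeds the per-response event store, and the endpoint immediately returns responses_background_stream_generator(request.request_id) instead of the final response object; non-streaming background requests keep the old path. The surrounding bookkeeping is the usual create_task plus add_done_callback pairing, sketched below with illustrative names (not vLLM code): the task handle is tracked while it runs and removed from the dict as soon as it completes.

# Illustrative sketch of the task-tracking pattern used above (not vLLM code).
import asyncio

background_tasks: dict[str, asyncio.Task] = {}


async def run_request(request_id: str) -> None:
    await asyncio.sleep(0.01)  # stand-in for the real generation work
    print(f"{request_id} done")


def schedule(request_id: str) -> asyncio.Task:
    task = asyncio.create_task(run_request(request_id),
                               name=f"create_{request_id}")
    background_tasks[request_id] = task
    # Drop the entry once the task finishes so the dict only holds live tasks.
    task.add_done_callback(lambda _: background_tasks.pop(request_id, None))
    return task


async def main() -> None:
    tasks = [schedule("resp_1"), schedule("resp_2")]
    await asyncio.gather(*tasks)
    await asyncio.sleep(0)      # let the done-callbacks run
    print(background_tasks)     # -> {}


asyncio.run(main())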
@@ -736,6 +752,40 @@ class OpenAIServingResponses(OpenAIServing):
             prev_outputs.append(response_msg)
         return messages
 
+    async def _run_background_request_stream(
+        self,
+        request: ResponsesRequest,
+        *args,
+        **kwargs,
+    ):
+        event_deque: deque[str] = deque()
+        new_event_signal = asyncio.Event()
+        self.event_store[request.request_id] = (event_deque, new_event_signal)
+        response = None
+        try:
+            generator = self.responses_stream_generator(
+                request, *args, **kwargs)
+            async for event in generator:
+                event_deque.append(event)
+                new_event_signal.set()  # Signal new event available
+        except Exception as e:
+            logger.exception("Background request failed for %s",
+                             request.request_id)
+            response = self.create_error_response(str(e))
+        finally:
+            # Mark as finished with a special marker
+            event_deque.append("__STREAM_END__")
+            new_event_signal.set()
+
+        if response is not None and isinstance(response, ErrorResponse):
+            # If the request has failed, update the status to "failed".
+            response_id = request.request_id
+            async with self.response_store_lock:
+                stored_response = self.response_store.get(response_id)
+                assert stored_response is not None
+                if stored_response.status not in ("completed", "cancelled"):
+                    stored_response.status = "failed"
+
     async def _run_background_request(
         self,
         request: ResponsesRequest,
@@ -759,9 +809,36 @@ class OpenAIServingResponses(OpenAIServing):
                 if stored_response.status not in ("completed", "cancelled"):
                     stored_response.status = "failed"
 
+    async def responses_background_stream_generator(
+        self,
+        response_id: str,
+        starting_after: Optional[int] = None,
+    ):
+        if response_id not in self.event_store:
+            raise ValueError(f"Unknown response_id: {response_id}")
+
+        event_deque, new_event_signal = self.event_store[response_id]
+        start_index = 0 if starting_after is None else starting_after + 1
+        current_index = start_index
+
+        while True:
+            new_event_signal.clear()
+
+            # Yield existing events from start_index
+            while current_index < len(event_deque):
+                event = event_deque[current_index]
+                if event == "__STREAM_END__":
+                    return
+                yield event
+                current_index += 1
+
+            await new_event_signal.wait()
+
     async def retrieve_responses(
         self,
         response_id: str,
+        starting_after: Optional[int],
+        stream: Optional[bool],
     ) -> Union[ErrorResponse, ResponsesResponse]:
         if not response_id.startswith("resp_"):
             return self._make_invalid_id_error(response_id)
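responses_background_stream_generator treats starting_after as the index of the last event the client has already seen: replay starts at starting_after + 1 (or at 0 when it is None), drains whatever is in the deque, and then waits on the asyncio.Event for more until it hits the __STREAM_END__ sentinel. This is why the test above seeds its counter with starting_after and increments it before each comparison. A tiny worked check of that index arithmetic (illustrative only):

# Worked check of the starting_after arithmetic (illustrative only).
events = [f"event-{i}" for i in range(10)]   # events[0] .. events[9]
starting_after = 5

start_index = 0 if starting_after is None else starting_after + 1
replayed = events[start_index:]              # events[6] .. events[9]

# Mirrors the test: the counter starts at starting_after and is incremented
# before each comparison, so the first replayed item is checked against
# events[6].
counter = starting_after
for event in replayed:
    counter += 1
    assert event == events[counter]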
@@ -771,6 +848,12 @@ class OpenAIServingResponses(OpenAIServing):
 
         if response is None:
             return self._make_not_found_error(response_id)
+
+        if stream:
+            return self.responses_background_stream_generator(
+                response_id,
+                starting_after,
+            )
         return response
 
     async def cancel_responses(