[Feature][Responses API]Support MCP tools with streaming mode + background mode (#23927)
Signed-off-by: wuhang <wuhang6@huawei.com>
parent b5ee1e3261
commit a38f8bd54c
@@ -275,7 +275,8 @@ async def test_stateful_multi_turn(client: OpenAI, model_name: str):
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_streaming(client: OpenAI, model_name: str):
+@pytest.mark.parametrize("background", [True, False])
+async def test_streaming(client: OpenAI, model_name: str, background: bool):
     # TODO: Add back when web search and code interpreter are available in CI
     prompts = [
         "tell me a story about a cat in 20 words",
@@ -300,11 +301,16 @@ async def test_streaming(client: OpenAI, model_name: str):
             # },
         ],
         stream=True,
+        background=background,
     )
 
     events = []
     current_event_mode = None
+    resp_id = None
     async for event in response:
+        if event.type == "response.created":
+            resp_id = event.response.id
+
         if current_event_mode != event.type:
             current_event_mode = event.type
             print(f"\n[{event.type}] ", end="", flush=True)
@@ -322,6 +328,17 @@ async def test_streaming(client: OpenAI, model_name: str):
 
     assert len(events) > 0
 
+    if background:
+        starting_after = 5
+        async with await client.responses.retrieve(
+                response_id=resp_id,
+                stream=True,
+                starting_after=starting_after) as stream:
+            counter = starting_after
+            async for event in stream:
+                counter += 1
+                assert event == events[counter]
+
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
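For reference, the client-side flow the new test exercises looks roughly like the sketch below. It is not part of the commit: the base URL, API key, and model name are placeholder assumptions, but the create/retrieve arguments (stream, background, starting_after) are exactly the ones the test above passes.

# Sketch of the client-side flow exercised by test_streaming (assumptions:
# an OpenAI-compatible vLLM server on localhost:8000 and a placeholder model).
import asyncio

from openai import AsyncOpenAI

MODEL = "my-model"  # placeholder model name


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    # Start a background response and stream its events as they are produced.
    response = await client.responses.create(
        model=MODEL,
        input="tell me a story about a cat in 20 words",
        stream=True,
        background=True,
    )

    events = []
    resp_id = None
    async for event in response:
        if event.type == "response.created":
            resp_id = event.response.id
        events.append(event)

    # Re-attach to the same background response later, skipping the events the
    # client already has (indices 0..5).
    replayed = []
    async with await client.responses.retrieve(
            response_id=resp_id, stream=True, starting_after=5) as replay:
        async for event in replay:
            replayed.append(event)
    assert replayed == events[6:]


asyncio.run(main())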
@@ -616,14 +616,23 @@ async def create_responses(request: ResponsesRequest, raw_request: Request):
 
 
 @router.get("/v1/responses/{response_id}")
-async def retrieve_responses(response_id: str, raw_request: Request):
+async def retrieve_responses(
+    response_id: str,
+    raw_request: Request,
+    starting_after: Optional[int] = None,
+    stream: Optional[bool] = False,
+):
     handler = responses(raw_request)
     if handler is None:
         return base(raw_request).create_error_response(
             message="The model does not support Responses API")
 
     try:
-        response = await handler.retrieve_responses(response_id)
+        response = await handler.retrieve_responses(
+            response_id,
+            starting_after=starting_after,
+            stream=stream,
+        )
     except Exception as e:
         raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
                             detail=str(e)) from e
@@ -631,6 +640,9 @@ async def retrieve_responses(response_id: str, raw_request: Request):
     if isinstance(response, ErrorResponse):
         return JSONResponse(content=response.model_dump(),
                             status_code=response.error.code)
+    elif stream:
+        return StreamingResponse(content=response,
+                                 media_type="text/event-stream")
     return JSONResponse(content=response.model_dump())
 
 
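At the HTTP level the change adds two query parameters to GET /v1/responses/{response_id}: stream and starting_after. When stream is false the endpoint returns the stored JSON object as before; when stream is true it replays the buffered event stream as server-sent events, optionally resuming after a given event index. A rough illustration with httpx follows; the host, port, and response id are placeholder assumptions, not values from the commit.

# Rough illustration of the new query parameters (not part of the commit).
import httpx

BASE = "http://localhost:8000"   # placeholder server address
RESP_ID = "resp_abc123"          # placeholder response id

# Plain JSON retrieval, unchanged behaviour.
snapshot = httpx.get(f"{BASE}/v1/responses/{RESP_ID}").json()

# SSE replay of the stored event stream, resuming after event index 5.
with httpx.stream(
        "GET",
        f"{BASE}/v1/responses/{RESP_ID}",
        params={"stream": "true", "starting_after": 5},
        timeout=None,
) as resp:
    for line in resp.iter_lines():
        if line:
            print(line)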
@@ -4,6 +4,7 @@
 import asyncio
 import json
 import time
+from collections import deque
 from collections.abc import AsyncGenerator, AsyncIterator, Sequence
 from contextlib import AsyncExitStack
 from copy import copy
@@ -55,7 +56,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
 # yapf: enable
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
-from vllm.entrypoints.tool_server import MCPToolServer, ToolServer
+from vllm.entrypoints.tool_server import ToolServer
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob as SampleLogprob
@@ -168,6 +169,11 @@ class OpenAIServingResponses(OpenAIServing):
         # never remove messages from the store.
         self.msg_store: dict[str, list[ChatCompletionMessageParam]] = {}
 
+        # HACK(wuhang): This is a hack. We should use a better store.
+        # FIXME: If enable_store=True, this may cause a memory leak since we
+        # never remove events from the store.
+        self.event_store: dict[str, tuple[deque[str], asyncio.Event]] = {}
+
         self.background_tasks: dict[str, asyncio.Task] = {}
 
         self.tool_server = tool_server
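The new event_store maps a request id to a deque of already-rendered SSE event strings plus an asyncio.Event used as a "new data or finished" signal. The sketch below is a minimal, self-contained illustration of that buffering/replay pattern; it is not the vLLM implementation, and the names and timings are made up. A producer appends events and wakes waiters, and a consumer that attaches at any point replays what is already buffered, then blocks for more until it sees the end-of-stream sentinel.

# Minimal sketch of the deque + asyncio.Event replay pattern behind
# event_store (illustrative only; not the vLLM code).
import asyncio
from collections import deque
from typing import Optional

STREAM_END = "__STREAM_END__"  # in-band sentinel marking completion


async def producer(buf: deque, signal: asyncio.Event) -> None:
    for i in range(10):
        await asyncio.sleep(0.01)   # stand-in for producing one SSE event
        buf.append(f"event {i}")
        signal.set()                # wake any waiting consumer
    buf.append(STREAM_END)          # consumers that attach later still finish
    signal.set()


async def consumer(buf: deque, signal: asyncio.Event,
                   starting_after: Optional[int] = None) -> list:
    seen = []
    index = 0 if starting_after is None else starting_after + 1
    while True:
        signal.clear()
        # Replay everything buffered so far.
        while index < len(buf):
            event = buf[index]
            if event == STREAM_END:
                return seen
            seen.append(event)
            index += 1
        # Nothing new yet: wait for the producer to signal again.
        await signal.wait()


async def main() -> None:
    buf: deque = deque()
    signal = asyncio.Event()
    prod = asyncio.create_task(producer(buf, signal))
    # Consumer that skips the first six events (indices 0..5).
    replayed = await consumer(buf, signal, starting_after=5)
    await prod
    assert replayed == [f"event {i}" for i in range(6, 10)]


asyncio.run(main())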
@@ -249,15 +255,6 @@ class OpenAIServingResponses(OpenAIServing):
         if raw_request:
             raw_request.state.request_metadata = request_metadata
 
-        if self.tool_server is not None and isinstance(
-                self.tool_server,
-                MCPToolServer) and request.stream and request.tools and any(
-                    tool.type in ["web_search_preview", "code_interpreter"]
-                    for tool in request.tools):
-            return self.create_error_response(
-                "MCP tool server is not supported in background mode and "
-                "streaming mode")
-
         # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[ConversationContext, None]] = []
 
@@ -329,25 +326,44 @@ class OpenAIServingResponses(OpenAIServing):
                 self.response_store[response.id] = response
 
             # Run the request in the background.
-            task = asyncio.create_task(
-                self._run_background_request(
-                    request,
-                    sampling_params,
-                    result_generator,
-                    context,
-                    model_name,
-                    tokenizer,
-                    request_metadata,
-                    created_time,
-                ),
-                name=f"create_{response.id}",
-            )
+            if request.stream:
+                task = asyncio.create_task(
+                    self._run_background_request_stream(
+                        request,
+                        sampling_params,
+                        result_generator,
+                        context,
+                        model_name,
+                        tokenizer,
+                        request_metadata,
+                        created_time,
+                    ),
+                    name=f"create_{request.request_id}",
+                )
+            else:
+                task = asyncio.create_task(
+                    self._run_background_request(
+                        request,
+                        sampling_params,
+                        result_generator,
+                        context,
+                        model_name,
+                        tokenizer,
+                        request_metadata,
+                        created_time,
+                    ),
+                    name=f"create_{response.id}",
+                )
 
             # For cleanup.
             response_id = response.id
             self.background_tasks[response_id] = task
             task.add_done_callback(
                 lambda _: self.background_tasks.pop(response_id, None))
+
+            if request.stream:
+                return self.responses_background_stream_generator(
+                    request.request_id)
             return response
 
         if request.stream:
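The bookkeeping around the task creation (unchanged in the context lines above) is the usual fire-and-forget idiom: keep a strong reference to the task so it is not garbage-collected while running, and drop that reference from a done-callback. A tiny standalone sketch of the idiom, with made-up identifiers:

# Track a named background task and remove the reference when it finishes
# (illustrative sketch; identifiers are made up).
import asyncio

background_tasks: dict = {}


async def run_request(response_id: str) -> None:
    await asyncio.sleep(0.05)   # stand-in for the real request processing


async def main() -> None:
    response_id = "resp_abc123"   # placeholder id
    task = asyncio.create_task(run_request(response_id),
                               name=f"create_{response_id}")
    background_tasks[response_id] = task
    # pop(..., None) keeps the callback safe even if the entry is already gone.
    task.add_done_callback(
        lambda _: background_tasks.pop(response_id, None))
    await task
    await asyncio.sleep(0)        # let the done-callback run
    assert response_id not in background_tasks


asyncio.run(main())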
@@ -736,6 +752,40 @@ class OpenAIServingResponses(OpenAIServing):
             prev_outputs.append(response_msg)
         return messages
 
+    async def _run_background_request_stream(
+        self,
+        request: ResponsesRequest,
+        *args,
+        **kwargs,
+    ):
+        event_deque: deque[str] = deque()
+        new_event_signal = asyncio.Event()
+        self.event_store[request.request_id] = (event_deque, new_event_signal)
+        response = None
+        try:
+            generator = self.responses_stream_generator(
+                request, *args, **kwargs)
+            async for event in generator:
+                event_deque.append(event)
+                new_event_signal.set()  # Signal new event available
+        except Exception as e:
+            logger.exception("Background request failed for %s",
+                             request.request_id)
+            response = self.create_error_response(str(e))
+        finally:
+            # Mark as finished with a special marker
+            event_deque.append("__STREAM_END__")
+            new_event_signal.set()
+
+        if response is not None and isinstance(response, ErrorResponse):
+            # If the request has failed, update the status to "failed".
+            response_id = request.request_id
+            async with self.response_store_lock:
+                stored_response = self.response_store.get(response_id)
+                assert stored_response is not None
+                if stored_response.status not in ("completed", "cancelled"):
+                    stored_response.status = "failed"
+
     async def _run_background_request(
         self,
         request: ResponsesRequest,
@@ -759,9 +809,36 @@ class OpenAIServingResponses(OpenAIServing):
                 if stored_response.status not in ("completed", "cancelled"):
                     stored_response.status = "failed"
 
+    async def responses_background_stream_generator(
+        self,
+        response_id: str,
+        starting_after: Optional[int] = None,
+    ):
+        if response_id not in self.event_store:
+            raise ValueError(f"Unknown response_id: {response_id}")
+
+        event_deque, new_event_signal = self.event_store[response_id]
+        start_index = 0 if starting_after is None else starting_after + 1
+        current_index = start_index
+
+        while True:
+            new_event_signal.clear()
+
+            # Yield existing events from start_index
+            while current_index < len(event_deque):
+                event = event_deque[current_index]
+                if event == "__STREAM_END__":
+                    return
+                yield event
+                current_index += 1
+
+            await new_event_signal.wait()
+
     async def retrieve_responses(
         self,
         response_id: str,
+        starting_after: Optional[int],
+        stream: Optional[bool],
     ) -> Union[ErrorResponse, ResponsesResponse]:
         if not response_id.startswith("resp_"):
             return self._make_invalid_id_error(response_id)
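One convention worth spelling out: starting_after is the index of the last event the client already received, so the replay begins at starting_after + 1. That matches the test above, where counter starts at starting_after and is incremented before each comparison. A quick sanity check of the arithmetic (illustrative, not vLLM code):

# starting_after names the last event index already seen; replay resumes at
# the next index.
buffered = [f"event {i}" for i in range(10)]
starting_after = 5
start_index = 0 if starting_after is None else starting_after + 1
assert buffered[start_index:] == [f"event {i}" for i in range(6, 10)]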
@@ -771,6 +848,12 @@ class OpenAIServingResponses(OpenAIServing):
 
         if response is None:
             return self._make_not_found_error(response_id)
+
+        if stream:
+            return self.responses_background_stream_generator(
+                response_id,
+                starting_after,
+            )
         return response
 
     async def cancel_responses(