refactor: encapsulate raw ZMQ multipart messages in typed ZmqMsgRequest/ZmqMsgResponse objects

Signed-off-by: clark <panf2333@gmail.com>
This commit is contained in:
clark 2025-03-09 13:00:25 +08:00
parent 912031ceb5
commit d35dace985
3 changed files with 152 additions and 139 deletions

View File

@ -1,13 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import json
import signal
import sys
import traceback
import uuid
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import Union
import uvicorn
@ -17,9 +17,8 @@ import zmq.asyncio
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.entrypoints.openai.zmq_server import (CONTENT_TYPE_ERROR,
CONTENT_TYPE_JSON,
CONTENT_TYPE_STREAM)
from vllm.entrypoints.openai.protocol import (CompletionRequest, ZmqMsgRequest,
ZmqMsgResponse)
from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser
@ -28,6 +27,7 @@ logger = init_logger('vllm.entrypoints.disagg_connector')
TIME_OUT = 5
X_REQUEST_ID_KEY = "X-Request-Id"
CONTENT_TYPE_STREAM = "text/event-stream"
# communication between output handlers and execute_task_async
request_queues: dict[str, asyncio.Queue]
@ -77,40 +77,44 @@ app = FastAPI(lifespan=lifespan)
@app.post('/v1/completions')
async def completions(request: Request, background_tasks: BackgroundTasks):
async def completions(request: CompletionRequest, raw_request: Request,
background_tasks: BackgroundTasks):
try:
# Add the X-Request-Id header to the raw headers list
header = dict(request.headers)
header = dict(raw_request.headers)
request_id = header.get(X_REQUEST_ID_KEY)
queue = asyncio.Queue()
queue: asyncio.Queue[ZmqMsgResponse] = asyncio.Queue()
if request_id is None:
request_id = str(uuid.uuid4())
logger.debug("add X-Request-Id: %s", request_id)
header[X_REQUEST_ID_KEY] = request_id
logger.debug("X-Request-Id is: %s", request_id)
request_queues[request_id] = queue
request_data = await request.json()
zmq_msg_request = ZmqMsgRequest(request_id=request_id,
type="completions",
body=request)
logger.info("Received request_id: %s, request: %s, header: %s",
request_id, request_data, header)
original_max_tokens = request_data['max_tokens']
request_id, zmq_msg_request.model_dump_json(), header)
original_max_tokens = request.max_tokens
# change max_tokens = 1 to let it only do prefill
request_data['max_tokens'] = 1
request.max_tokens = 1
# finish prefill
try:
prefill_response = await prefill(header, request_data)
if isinstance(prefill_response, JSONResponse):
prefill_response = await prefill(zmq_msg_request)
if isinstance(prefill_response, JSONResponse
) and prefill_response.status_code != HTTPStatus.OK:
return prefill_response
logger.debug("finish prefill start decode")
request_data['max_tokens'] = original_max_tokens
response = await decode(header, request_data)
request.max_tokens = original_max_tokens
response = await decode(zmq_msg_request)
logger.debug("finish decode")
except Exception as e:
logger.error("Error occurred in disagg prefill proxy server, %s",
e)
response = JSONResponse({"error": {
"message": str(e)
}},
status_code=500)
response = JSONResponse(
{"error": {
"message": str(e)
}},
status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
return response
except Exception as e:
@ -120,37 +124,27 @@ async def completions(request: Request, background_tasks: BackgroundTasks):
logger.error("".join(traceback.format_exception(*exc_info)))
response = JSONResponse({"error": {
"message": str(e)
}},
status_code=500)
}}, HTTPStatus.INTERNAL_SERVER_ERROR)
return response
finally:
background_tasks.add_task(cleanup_request_id, request_id)
if request_id is not None:
background_tasks.add_task(cleanup_request_id, request_id)
async def socket_recv_handler(socket: zmq.asyncio.Socket, scene: str):
while True:
try:
[request_id, contentType, reply] = await socket.recv_multipart()
contentType_str = contentType.decode()
reply_str = reply.decode()
request_id_str = request_id.decode()
logger.debug(
"%s socket received result contentType: %s, "
"request_id: %s, reply: %s", scene, contentType_str,
request_id_str, reply_str)
if request_id_str in request_queues:
request_queues[request_id_str].put_nowait(
(contentType_str, reply_str))
if "[DONE]" in reply_str:
logger.debug(
"%s socket received stop signal request_id: %s", scene,
request_id_str)
request_queues[request_id_str].put_nowait(
(contentType_str, None))
[body] = await socket.recv_multipart()
response = ZmqMsgResponse.model_validate_json(body)
request_id = response.request_id
logger.debug("%s socket received result: %s", scene,
response.model_dump_json())
if request_id in request_queues:
request_queues[request_id].put_nowait(response)
else:
logger.debug(
"%s socket received but request_id not found discard: %s",
scene, request_id_str)
scene, request_id)
except Exception as e:
logger.error(traceback.format_exc())
logger.error("%s handler error: %s", scene, e)
@ -167,78 +161,70 @@ async def decode_handler(decode_socket: zmq.asyncio.Socket):
# select a socket and execute task
async def execute_task_async(headers: dict, request: dict,
async def execute_task_async(zmq_msg_request: ZmqMsgRequest,
socket: zmq.asyncio.Socket):
try:
request_id = headers.get(X_REQUEST_ID_KEY)
requestBody = json.dumps(request)
logger.info("Sending requestBody: %s", requestBody)
socket.send_multipart([request_id.encode(), requestBody.encode()])
request_id = zmq_msg_request.request_id
requestBody = zmq_msg_request.model_dump_json()
logger.debug("Sending requestBody: %s", requestBody)
socket.send_multipart([requestBody.encode()])
logger.debug("Sent end")
queue = request_queues[request_id]
while True:
logger.debug("Waiting for reply")
(contentType,
reply) = await asyncio.wait_for(queue.get(), TIME_OUT)
logger.debug("Received result: %s, %s", contentType, reply)
if reply is None:
logger.debug("Received stop signal, request_id: %s",
request_id)
zmq_msg_response: ZmqMsgResponse = await asyncio.wait_for(
queue.get(), TIME_OUT)
logger.debug("Received result: %s",
zmq_msg_response.model_dump_json())
yield zmq_msg_response
if zmq_msg_response.stop:
logger.debug("Received stop: %s", zmq_msg_response.stop)
break
yield (contentType, reply)
if contentType == CONTENT_TYPE_JSON:
logger.debug("Received %s message, request_id: %s",
contentType, request_id)
break
except asyncio.TimeoutError:
logger.error(traceback.format_exc())
yield (CONTENT_TYPE_ERROR, "System Error")
yield JSONResponse("timeout", HTTPStatus.REQUEST_TIMEOUT)
finally:
logger.debug("request_id: %s, execute_task_async end", request_id)
async def prefill(header: dict,
original_request_data: dict) -> Union[JSONResponse, bool]:
async def prefill(zmq_msg_request: ZmqMsgRequest) -> Union[JSONResponse, bool]:
logger.debug("start prefill")
generator = execute_task_async(header, original_request_data,
app.state.prefill_socket)
async for contentType, reply in generator:
logger.debug("contentType: %s, reply: %s", contentType, reply)
if contentType == CONTENT_TYPE_ERROR:
response = JSONResponse({"error": reply})
response.status_code = 500
return response
generator = execute_task_async(zmq_msg_request, app.state.prefill_socket)
async for res in generator:
logger.debug("res: %s", res)
if res.body_type == "response":
return JSONResponse(res.body)
return True
async def generate_stream_response(fisrt_reply: str,
generator: AsyncGenerator):
async def generate_stream_response(
fisrt_reply: str, generator: AsyncGenerator[ZmqMsgResponse]
) -> AsyncGenerator[dict, str]:
yield fisrt_reply
async for _, reply in generator:
yield reply
async for reply in generator:
yield reply.body
async def decode(
header: dict,
original_request_data: dict) -> Union[JSONResponse, StreamingResponse]:
logger.info("start decode")
generator = execute_task_async(header, original_request_data,
app.state.decode_socket)
zmq_msg_request: ZmqMsgRequest
) -> Union[JSONResponse, StreamingResponse]:
logger.debug("start decode")
generator = execute_task_async(zmq_msg_request, app.state.decode_socket)
async for contentType, reply in generator:
logger.debug("contentType: %s, reply: %s", contentType, reply)
if contentType == CONTENT_TYPE_ERROR:
response = JSONResponse({"error": reply})
response.status_code = 500
return response
elif contentType == CONTENT_TYPE_JSON:
return JSONResponse(reply)
async for res in generator:
logger.debug("res: %s", res)
if res.body_type == "response":
return JSONResponse(res.body)
else:
return StreamingResponse(generate_stream_response(
reply, generator),
res.body, generator),
media_type=CONTENT_TYPE_STREAM)
# If the generator is empty, return a default error response
logger.error("No response received from generator")
return JSONResponse({"error": "No response received from generator"},
status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
def cleanup_request_id(request_id: str):
if request_id in request_queues:

View File

@ -1649,3 +1649,19 @@ class TranscriptionResponseVerbose(OpenAIBaseModel):
words: Optional[list[TranscriptionWord]] = None
"""Extracted words and their corresponding timestamps."""
class ZmqMsgRequest(BaseModel):
    """Envelope for a request forwarded over ZMQ from the proxy to the server.

    Serialized with ``model_dump_json()`` and sent as a single multipart
    frame; the server reconstructs it via ``model_validate_json()``.
    """

    # Correlates ZMQ replies with the originating HTTP request
    # (taken from, or generated for, the X-Request-Id header).
    request_id: str
    # Request category, e.g. "completions".
    type: str
    # The actual OpenAI-style request payload. A bare single-member
    # Union[CompletionRequest] is equivalent to the type itself; widen this
    # union when additional request types (e.g. chat) are supported.
    body: CompletionRequest
class ZmqMsgResponse(BaseModel):
    """Envelope for a reply sent back over ZMQ from the server to the proxy."""

    # Echoes ZmqMsgRequest.request_id so the proxy can route the reply
    # to the waiting per-request queue.
    request_id: str
    # Mirrors the request's type field, e.g. "completions".
    type: str
    # True on the final message for a request: a one-shot JSON response,
    # or the streaming chunk containing "data: [DONE]".
    stop: bool = True
    # "response": body is a final JSON payload to wrap in a JSONResponse;
    # "str": body is a raw SSE chunk to forward on the event stream.
    body_type: Literal["str", "response"] = "str"
    # None until populated by the sender. The annotation must include None
    # for the default to be type-correct (pydantic skips default validation,
    # but static checkers flag `Union[dict, str] = None`).
    body: Union[dict, str, None] = None
    model_config = ConfigDict(arbitrary_types_allowed=True)

View File

@ -1,22 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import json
import os
import signal
import traceback
from argparse import Namespace
from http import HTTPStatus
import zmq
import zmq.asyncio
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.entrypoints.openai.protocol import (CompletionRequest,
CompletionResponse,
ErrorResponse)
ErrorResponse, ZmqMsgRequest,
ZmqMsgResponse)
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
@ -26,10 +26,6 @@ from vllm.version import __version__ as VLLM_VERSION
logger = init_logger('vllm.entrypoints.openai.zmq_server')
CONTENT_TYPE_JSON = "application/json"
CONTENT_TYPE_ERROR = "error"
CONTENT_TYPE_STREAM = "text/event-stream"
openai_serving_completion: OpenAIServingCompletion
openai_serving_models: OpenAIServingModels
@ -72,16 +68,22 @@ async def serve_zmq(arg) -> None:
try:
logger.debug("zmq Server waiting for request")
# get new request from the client
identity, request_id, body = await socket.recv_multipart()
message_parts = await socket.recv_multipart()
logger.debug("received request: %s", message_parts)
logger.debug("received len: %d", len(message_parts))
identity, body = message_parts[0], message_parts[1]
zmq_msg_request = ZmqMsgRequest.model_validate_json(body)
# launch request handler coroutine
task = asyncio.create_task(
worker_routine(identity, request_id, body, socket))
worker_routine(identity, zmq_msg_request, socket))
running_requests.add(task)
task.add_done_callback(running_requests.discard)
except zmq.ZMQError as e:
logger.error(traceback.format_exc())
logger.error("ZMQError: %s", e)
break
except Exception as e:
logger.error(traceback.format_exc())
logger.error("Unexpected error: %s", e)
break
except KeyboardInterrupt:
@ -153,59 +155,68 @@ async def init_state(
)
async def worker_routine(identity: bytes, request_id: bytes, body: bytes,
async def worker_routine(identity: bytes, zmq_msg_request: ZmqMsgRequest,
socket: zmq.asyncio.Socket):
"""Worker routine"""
try:
body_json = json.loads(body.decode())
request_id_str = request_id.decode()
logger.debug("receive request: %s from %s, request_id: %s", body_json,
identity.decode(), request_id_str)
completionRequest = CompletionRequest(**body_json)
generator = await create_completion(completionRequest, None)
content_type_json = CONTENT_TYPE_JSON.encode('utf-8')
content_type_stream = CONTENT_TYPE_STREAM.encode('utf-8')
if isinstance(generator, ErrorResponse):
content = json.loads(generator.model_dump_json())
content.update({"status_code": generator.code})
logger.debug("send ErrorResponse %s", json.dumps(content))
await socket.send_multipart([
identity, request_id, content_type_json,
json.dumps(content).encode('utf-8')
])
elif isinstance(generator, CompletionResponse):
logger.debug("send CompletionResponse %s",
json.dumps(generator.model_dump()))
await socket.send_multipart([
identity, request_id, content_type_json,
json.dumps(generator.model_dump()).encode('utf-8')
])
request_id = zmq_msg_request.request_id
logger.debug("receive request: %s from %s, request_id: %s",
zmq_msg_request.model_dump_json(), identity.decode(),
request_id)
if isinstance(zmq_msg_request.body, CompletionRequest):
await create_completion(identity, zmq_msg_request, socket)
else:
async for chunk in generator:
logger.debug(
"send chunk identity: %s, request_id: %s, chunk: %s",
identity.decode(), request_id.decode(), chunk)
await socket.send_multipart([
identity, request_id, content_type_stream,
chunk.encode('utf-8')
])
logger.error("Error in worker routine: %s request_id: %s",
"unsupported request type", request_id)
raise Exception("unsupported request type")
except Exception as e:
logger.error("Error in worker routine: %s request_id: %s", e,
request_id_str)
request_id)
logger.error(traceback.format_exc())
content_type_stream = CONTENT_TYPE_STREAM.encode('utf-8')
logger.debug("send ErrorResponse %s", str(e))
await socket.send_multipart([
identity, request_id,
CONTENT_TYPE_ERROR.encode('utf-8'),
str(e).encode('utf-8')
identity,
ZmqMsgResponse(request_id=request_id,
type=zmq_msg_request.type,
body={
"content": "unsupported request type",
"status_code": HTTPStatus.INTERNAL_SERVER_ERROR
}).model_dump_json().encode(),
])
async def create_completion(request: CompletionRequest, raw_request: Request):
logger.debug("zmq request post: %s", request)
generator = await openai_serving_completion.create_completion(
request, raw_request)
async def create_completion(identity: bytes, zmq_msg_request: ZmqMsgRequest,
socket: zmq.asyncio.Socket):
request: CompletionRequest = zmq_msg_request.body
logger.debug("zmq request post: %s", request.model_dump_json())
generator = await openai_serving_completion.create_completion(request)
logger.debug("zmq request end post")
return generator
request_id = zmq_msg_request.request_id
if isinstance(generator, (ErrorResponse, CompletionResponse)):
logger.debug("send response %s", generator.model_dump_json())
zmq_msg_response = ZmqMsgResponse(request_id=request_id,
type=zmq_msg_request.type,
body_type="response")
if isinstance(generator, ErrorResponse):
zmq_msg_response.body = {
"content": generator.model_dump(),
"status_code": generator.code
}
elif isinstance(generator, CompletionResponse):
zmq_msg_response.body = {"content": generator.model_dump()}
await socket.send_multipart(
[identity, zmq_msg_response.model_dump_json().encode()])
else:
async for chunk in generator:
zmq_msg_response = ZmqMsgResponse(request_id=request_id,
type=zmq_msg_request.type,
body=chunk)
if "data: [DONE]" not in chunk:
zmq_msg_response.stop = False
logger.debug("send chunk identity: %s, request_id: %s, chunk: %s",
identity.decode(), request_id, chunk)
await socket.send_multipart(
[identity,
zmq_msg_response.model_dump_json().encode()])