refactor zmq msg to object

Signed-off-by: clark <panf2333@gmail.com>
This commit is contained in:
clark 2025-03-09 13:00:25 +08:00
parent 912031ceb5
commit d35dace985
3 changed files with 152 additions and 139 deletions

View File

@ -1,13 +1,13 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio import asyncio
import json
import signal import signal
import sys import sys
import traceback import traceback
import uuid import uuid
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import Union from typing import Union
import uvicorn import uvicorn
@ -17,9 +17,8 @@ import zmq.asyncio
from fastapi import BackgroundTasks, FastAPI, Request from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from vllm.entrypoints.openai.zmq_server import (CONTENT_TYPE_ERROR, from vllm.entrypoints.openai.protocol import (CompletionRequest, ZmqMsgRequest,
CONTENT_TYPE_JSON, ZmqMsgResponse)
CONTENT_TYPE_STREAM)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
@ -28,6 +27,7 @@ logger = init_logger('vllm.entrypoints.disagg_connector')
TIME_OUT = 5 TIME_OUT = 5
X_REQUEST_ID_KEY = "X-Request-Id" X_REQUEST_ID_KEY = "X-Request-Id"
CONTENT_TYPE_STREAM = "text/event-stream"
# communication between output handlers and execute_task_async # communication between output handlers and execute_task_async
request_queues: dict[str, asyncio.Queue] request_queues: dict[str, asyncio.Queue]
@ -77,40 +77,44 @@ app = FastAPI(lifespan=lifespan)
@app.post('/v1/completions') @app.post('/v1/completions')
async def completions(request: Request, background_tasks: BackgroundTasks): async def completions(request: CompletionRequest, raw_request: Request,
background_tasks: BackgroundTasks):
try: try:
# Add the X-Request-Id header to the raw headers list # Add the X-Request-Id header to the raw headers list
header = dict(request.headers) header = dict(raw_request.headers)
request_id = header.get(X_REQUEST_ID_KEY) request_id = header.get(X_REQUEST_ID_KEY)
queue = asyncio.Queue() queue: asyncio.Queue[ZmqMsgResponse] = asyncio.Queue()
if request_id is None: if request_id is None:
request_id = str(uuid.uuid4()) request_id = str(uuid.uuid4())
logger.debug("add X-Request-Id: %s", request_id) logger.debug("add X-Request-Id: %s", request_id)
header[X_REQUEST_ID_KEY] = request_id
logger.debug("X-Request-Id is: %s", request_id) logger.debug("X-Request-Id is: %s", request_id)
request_queues[request_id] = queue request_queues[request_id] = queue
request_data = await request.json() zmq_msg_request = ZmqMsgRequest(request_id=request_id,
type="completions",
body=request)
logger.info("Received request_id: %s, request: %s, header: %s", logger.info("Received request_id: %s, request: %s, header: %s",
request_id, request_data, header) request_id, zmq_msg_request.model_dump_json(), header)
original_max_tokens = request_data['max_tokens'] original_max_tokens = request.max_tokens
# change max_tokens = 1 to let it only do prefill # change max_tokens = 1 to let it only do prefill
request_data['max_tokens'] = 1 request.max_tokens = 1
# finish prefill # finish prefill
try: try:
prefill_response = await prefill(header, request_data) prefill_response = await prefill(zmq_msg_request)
if isinstance(prefill_response, JSONResponse): if isinstance(prefill_response, JSONResponse
) and prefill_response.status_code != HTTPStatus.OK:
return prefill_response return prefill_response
logger.debug("finish prefill start decode") logger.debug("finish prefill start decode")
request_data['max_tokens'] = original_max_tokens request.max_tokens = original_max_tokens
response = await decode(header, request_data) response = await decode(zmq_msg_request)
logger.debug("finish decode") logger.debug("finish decode")
except Exception as e: except Exception as e:
logger.error("Error occurred in disagg prefill proxy server, %s", logger.error("Error occurred in disagg prefill proxy server, %s",
e) e)
response = JSONResponse({"error": { response = JSONResponse(
"message": str(e) {"error": {
}}, "message": str(e)
status_code=500) }},
status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
return response return response
except Exception as e: except Exception as e:
@ -120,37 +124,27 @@ async def completions(request: Request, background_tasks: BackgroundTasks):
logger.error("".join(traceback.format_exception(*exc_info))) logger.error("".join(traceback.format_exception(*exc_info)))
response = JSONResponse({"error": { response = JSONResponse({"error": {
"message": str(e) "message": str(e)
}}, }}, HTTPStatus.INTERNAL_SERVER_ERROR)
status_code=500)
return response return response
finally: finally:
background_tasks.add_task(cleanup_request_id, request_id) if request_id is not None:
background_tasks.add_task(cleanup_request_id, request_id)
async def socket_recv_handler(socket: zmq.asyncio.Socket, scene: str): async def socket_recv_handler(socket: zmq.asyncio.Socket, scene: str):
while True: while True:
try: try:
[request_id, contentType, reply] = await socket.recv_multipart() [body] = await socket.recv_multipart()
contentType_str = contentType.decode() response = ZmqMsgResponse.model_validate_json(body)
reply_str = reply.decode() request_id = response.request_id
request_id_str = request_id.decode() logger.debug("%s socket received result: %s", scene,
logger.debug( response.model_dump_json())
"%s socket received result contentType: %s, " if request_id in request_queues:
"request_id: %s, reply: %s", scene, contentType_str, request_queues[request_id].put_nowait(response)
request_id_str, reply_str)
if request_id_str in request_queues:
request_queues[request_id_str].put_nowait(
(contentType_str, reply_str))
if "[DONE]" in reply_str:
logger.debug(
"%s socket received stop signal request_id: %s", scene,
request_id_str)
request_queues[request_id_str].put_nowait(
(contentType_str, None))
else: else:
logger.debug( logger.debug(
"%s socket received but request_id not found discard: %s", "%s socket received but request_id not found discard: %s",
scene, request_id_str) scene, request_id)
except Exception as e: except Exception as e:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
logger.error("%s handler error: %s", scene, e) logger.error("%s handler error: %s", scene, e)
@ -167,78 +161,70 @@ async def decode_handler(decode_socket: zmq.asyncio.Socket):
# select a socket and execute task # select a socket and execute task
async def execute_task_async(headers: dict, request: dict, async def execute_task_async(zmq_msg_request: ZmqMsgRequest,
socket: zmq.asyncio.Socket): socket: zmq.asyncio.Socket):
try: try:
request_id = headers.get(X_REQUEST_ID_KEY) request_id = zmq_msg_request.request_id
requestBody = json.dumps(request) requestBody = zmq_msg_request.model_dump_json()
logger.info("Sending requestBody: %s", requestBody) logger.debug("Sending requestBody: %s", requestBody)
socket.send_multipart([request_id.encode(), requestBody.encode()]) socket.send_multipart([requestBody.encode()])
logger.debug("Sent end") logger.debug("Sent end")
queue = request_queues[request_id] queue = request_queues[request_id]
while True: while True:
logger.debug("Waiting for reply") logger.debug("Waiting for reply")
(contentType, zmq_msg_response: ZmqMsgResponse = await asyncio.wait_for(
reply) = await asyncio.wait_for(queue.get(), TIME_OUT) queue.get(), TIME_OUT)
logger.debug("Received result: %s, %s", contentType, reply) logger.debug("Received result: %s",
if reply is None: zmq_msg_response.model_dump_json())
logger.debug("Received stop signal, request_id: %s", yield zmq_msg_response
request_id) if zmq_msg_response.stop:
logger.debug("Received stop: %s", zmq_msg_response.stop)
break break
yield (contentType, reply)
if contentType == CONTENT_TYPE_JSON:
logger.debug("Received %s message, request_id: %s",
contentType, request_id)
break
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
yield (CONTENT_TYPE_ERROR, "System Error") yield JSONResponse("timeout", HTTPStatus.REQUEST_TIMEOUT)
finally: finally:
logger.debug("request_id: %s, execute_task_async end", request_id) logger.debug("request_id: %s, execute_task_async end", request_id)
async def prefill(header: dict, async def prefill(zmq_msg_request: ZmqMsgRequest) -> Union[JSONResponse, bool]:
original_request_data: dict) -> Union[JSONResponse, bool]:
logger.debug("start prefill") logger.debug("start prefill")
generator = execute_task_async(header, original_request_data, generator = execute_task_async(zmq_msg_request, app.state.prefill_socket)
app.state.prefill_socket) async for res in generator:
async for contentType, reply in generator: logger.debug("res: %s", res)
logger.debug("contentType: %s, reply: %s", contentType, reply) if res.body_type == "response":
if contentType == CONTENT_TYPE_ERROR: return JSONResponse(res.body)
response = JSONResponse({"error": reply})
response.status_code = 500
return response
return True return True
async def generate_stream_response(fisrt_reply: str, async def generate_stream_response(
generator: AsyncGenerator): fisrt_reply: str, generator: AsyncGenerator[ZmqMsgResponse]
) -> AsyncGenerator[dict, str]:
yield fisrt_reply yield fisrt_reply
async for _, reply in generator: async for reply in generator:
yield reply yield reply.body
async def decode( async def decode(
header: dict, zmq_msg_request: ZmqMsgRequest
original_request_data: dict) -> Union[JSONResponse, StreamingResponse]: ) -> Union[JSONResponse, StreamingResponse]:
logger.info("start decode") logger.debug("start decode")
generator = execute_task_async(header, original_request_data, generator = execute_task_async(zmq_msg_request, app.state.decode_socket)
app.state.decode_socket)
async for contentType, reply in generator: async for res in generator:
logger.debug("contentType: %s, reply: %s", contentType, reply) logger.debug("res: %s", res)
if contentType == CONTENT_TYPE_ERROR: if res.body_type == "response":
response = JSONResponse({"error": reply}) return JSONResponse(res.body)
response.status_code = 500
return response
elif contentType == CONTENT_TYPE_JSON:
return JSONResponse(reply)
else: else:
return StreamingResponse(generate_stream_response( return StreamingResponse(generate_stream_response(
reply, generator), res.body, generator),
media_type=CONTENT_TYPE_STREAM) media_type=CONTENT_TYPE_STREAM)
# If the generator is empty, return a default error response
logger.error("No response received from generator")
return JSONResponse({"error": "No response received from generator"},
status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
def cleanup_request_id(request_id: str): def cleanup_request_id(request_id: str):
if request_id in request_queues: if request_id in request_queues:

View File

@ -1649,3 +1649,19 @@ class TranscriptionResponseVerbose(OpenAIBaseModel):
words: Optional[list[TranscriptionWord]] = None words: Optional[list[TranscriptionWord]] = None
"""Extracted words and their corresponding timestamps.""" """Extracted words and their corresponding timestamps."""
class ZmqMsgRequest(BaseModel):
request_id: str
type: str
body: Union[CompletionRequest]
class ZmqMsgResponse(BaseModel):
request_id: str
type: str
stop: bool = True
body_type: Literal["str", "response"] = "str"
body: Union[dict, str] = None
model_config = ConfigDict(arbitrary_types_allowed=True)

View File

@ -1,22 +1,22 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio import asyncio
import json
import os import os
import signal import signal
import traceback import traceback
from argparse import Namespace from argparse import Namespace
from http import HTTPStatus
import zmq import zmq
import zmq.asyncio import zmq.asyncio
from fastapi import Request
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.api_server import build_async_engine_client from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.entrypoints.openai.protocol import (CompletionRequest, from vllm.entrypoints.openai.protocol import (CompletionRequest,
CompletionResponse, CompletionResponse,
ErrorResponse) ErrorResponse, ZmqMsgRequest,
ZmqMsgResponse)
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import (BaseModelPath, from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels) OpenAIServingModels)
@ -26,10 +26,6 @@ from vllm.version import __version__ as VLLM_VERSION
logger = init_logger('vllm.entrypoints.openai.zmq_server') logger = init_logger('vllm.entrypoints.openai.zmq_server')
CONTENT_TYPE_JSON = "application/json"
CONTENT_TYPE_ERROR = "error"
CONTENT_TYPE_STREAM = "text/event-stream"
openai_serving_completion: OpenAIServingCompletion openai_serving_completion: OpenAIServingCompletion
openai_serving_models: OpenAIServingModels openai_serving_models: OpenAIServingModels
@ -72,16 +68,22 @@ async def serve_zmq(arg) -> None:
try: try:
logger.debug("zmq Server waiting for request") logger.debug("zmq Server waiting for request")
# get new request from the client # get new request from the client
identity, request_id, body = await socket.recv_multipart() message_parts = await socket.recv_multipart()
logger.debug("received request: %s", message_parts)
logger.debug("received len: %d", len(message_parts))
identity, body = message_parts[0], message_parts[1]
zmq_msg_request = ZmqMsgRequest.model_validate_json(body)
# launch request handler coroutine # launch request handler coroutine
task = asyncio.create_task( task = asyncio.create_task(
worker_routine(identity, request_id, body, socket)) worker_routine(identity, zmq_msg_request, socket))
running_requests.add(task) running_requests.add(task)
task.add_done_callback(running_requests.discard) task.add_done_callback(running_requests.discard)
except zmq.ZMQError as e: except zmq.ZMQError as e:
logger.error(traceback.format_exc())
logger.error("ZMQError: %s", e) logger.error("ZMQError: %s", e)
break break
except Exception as e: except Exception as e:
logger.error(traceback.format_exc())
logger.error("Unexpected error: %s", e) logger.error("Unexpected error: %s", e)
break break
except KeyboardInterrupt: except KeyboardInterrupt:
@ -153,59 +155,68 @@ async def init_state(
) )
async def worker_routine(identity: bytes, request_id: bytes, body: bytes, async def worker_routine(identity: bytes, zmq_msg_request: ZmqMsgRequest,
socket: zmq.asyncio.Socket): socket: zmq.asyncio.Socket):
"""Worker routine""" """Worker routine"""
try: try:
body_json = json.loads(body.decode()) request_id = zmq_msg_request.request_id
request_id_str = request_id.decode() logger.debug("receive request: %s from %s, request_id: %s",
logger.debug("receive request: %s from %s, request_id: %s", body_json, zmq_msg_request.model_dump_json(), identity.decode(),
identity.decode(), request_id_str) request_id)
if isinstance(zmq_msg_request.body, CompletionRequest):
completionRequest = CompletionRequest(**body_json) await create_completion(identity, zmq_msg_request, socket)
generator = await create_completion(completionRequest, None)
content_type_json = CONTENT_TYPE_JSON.encode('utf-8')
content_type_stream = CONTENT_TYPE_STREAM.encode('utf-8')
if isinstance(generator, ErrorResponse):
content = json.loads(generator.model_dump_json())
content.update({"status_code": generator.code})
logger.debug("send ErrorResponse %s", json.dumps(content))
await socket.send_multipart([
identity, request_id, content_type_json,
json.dumps(content).encode('utf-8')
])
elif isinstance(generator, CompletionResponse):
logger.debug("send CompletionResponse %s",
json.dumps(generator.model_dump()))
await socket.send_multipart([
identity, request_id, content_type_json,
json.dumps(generator.model_dump()).encode('utf-8')
])
else: else:
async for chunk in generator: logger.error("Error in worker routine: %s request_id: %s",
logger.debug( "unsupported request type", request_id)
"send chunk identity: %s, request_id: %s, chunk: %s", raise Exception("unsupported request type")
identity.decode(), request_id.decode(), chunk)
await socket.send_multipart([
identity, request_id, content_type_stream,
chunk.encode('utf-8')
])
except Exception as e: except Exception as e:
logger.error("Error in worker routine: %s request_id: %s", e, logger.error("Error in worker routine: %s request_id: %s", e,
request_id_str) request_id)
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
content_type_stream = CONTENT_TYPE_STREAM.encode('utf-8')
logger.debug("send ErrorResponse %s", str(e)) logger.debug("send ErrorResponse %s", str(e))
await socket.send_multipart([ await socket.send_multipart([
identity, request_id, identity,
CONTENT_TYPE_ERROR.encode('utf-8'), ZmqMsgResponse(request_id=request_id,
str(e).encode('utf-8') type=zmq_msg_request.type,
body={
"content": "unsupported request type",
"status_code": HTTPStatus.INTERNAL_SERVER_ERROR
}).model_dump_json().encode(),
]) ])
async def create_completion(request: CompletionRequest, raw_request: Request): async def create_completion(identity: bytes, zmq_msg_request: ZmqMsgRequest,
logger.debug("zmq request post: %s", request) socket: zmq.asyncio.Socket):
generator = await openai_serving_completion.create_completion( request: CompletionRequest = zmq_msg_request.body
request, raw_request) logger.debug("zmq request post: %s", request.model_dump_json())
generator = await openai_serving_completion.create_completion(request)
logger.debug("zmq request end post") logger.debug("zmq request end post")
return generator request_id = zmq_msg_request.request_id
if isinstance(generator, (ErrorResponse, CompletionResponse)):
logger.debug("send response %s", generator.model_dump_json())
zmq_msg_response = ZmqMsgResponse(request_id=request_id,
type=zmq_msg_request.type,
body_type="response")
if isinstance(generator, ErrorResponse):
zmq_msg_response.body = {
"content": generator.model_dump(),
"status_code": generator.code
}
elif isinstance(generator, CompletionResponse):
zmq_msg_response.body = {"content": generator.model_dump()}
await socket.send_multipart(
[identity, zmq_msg_response.model_dump_json().encode()])
else:
async for chunk in generator:
zmq_msg_response = ZmqMsgResponse(request_id=request_id,
type=zmq_msg_request.type,
body=chunk)
if "data: [DONE]" not in chunk:
zmq_msg_response.stop = False
logger.debug("send chunk identity: %s, request_id: %s, chunk: %s",
identity.decode(), request_id, chunk)
await socket.send_multipart(
[identity,
zmq_msg_response.model_dump_json().encode()])