From d35dace985c0d55ad9c6cbbc4235b89bd8d7d928 Mon Sep 17 00:00:00 2001 From: clark Date: Sun, 9 Mar 2025 13:00:25 +0800 Subject: [PATCH] refactor zmq msg to object Signed-off-by: clark --- vllm/entrypoints/disagg_connector.py | 160 ++++++++++++-------------- vllm/entrypoints/openai/protocol.py | 16 +++ vllm/entrypoints/openai/zmq_server.py | 115 +++++++++--------- 3 files changed, 152 insertions(+), 139 deletions(-) diff --git a/vllm/entrypoints/disagg_connector.py b/vllm/entrypoints/disagg_connector.py index cb6549dd40bc2..245188c3cf72d 100644 --- a/vllm/entrypoints/disagg_connector.py +++ b/vllm/entrypoints/disagg_connector.py @@ -1,13 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio -import json import signal import sys import traceback import uuid from collections.abc import AsyncGenerator from contextlib import asynccontextmanager +from http import HTTPStatus from typing import Union import uvicorn @@ -17,9 +17,8 @@ import zmq.asyncio from fastapi import BackgroundTasks, FastAPI, Request from fastapi.responses import JSONResponse, StreamingResponse -from vllm.entrypoints.openai.zmq_server import (CONTENT_TYPE_ERROR, - CONTENT_TYPE_JSON, - CONTENT_TYPE_STREAM) +from vllm.entrypoints.openai.protocol import (CompletionRequest, ZmqMsgRequest, + ZmqMsgResponse) from vllm.logger import init_logger from vllm.utils import FlexibleArgumentParser @@ -28,6 +27,7 @@ logger = init_logger('vllm.entrypoints.disagg_connector') TIME_OUT = 5 X_REQUEST_ID_KEY = "X-Request-Id" +CONTENT_TYPE_STREAM = "text/event-stream" # communication between output handlers and execute_task_async request_queues: dict[str, asyncio.Queue] @@ -77,40 +77,44 @@ app = FastAPI(lifespan=lifespan) @app.post('/v1/completions') -async def completions(request: Request, background_tasks: BackgroundTasks): +async def completions(request: CompletionRequest, raw_request: Request, + background_tasks: BackgroundTasks): try: # Add the X-Request-Id header to the raw headers list - header = dict(request.headers) + header = dict(raw_request.headers) request_id = header.get(X_REQUEST_ID_KEY) - queue = asyncio.Queue() + queue: asyncio.Queue[ZmqMsgResponse] = asyncio.Queue() if request_id is None: request_id = str(uuid.uuid4()) logger.debug("add X-Request-Id: %s", request_id) - header[X_REQUEST_ID_KEY] = request_id logger.debug("X-Request-Id is: %s", request_id) request_queues[request_id] = queue - request_data = await request.json() + zmq_msg_request = ZmqMsgRequest(request_id=request_id, + type="completions", + body=request) logger.info("Received request_id: %s, request: %s, header: %s", - request_id, request_data, header) - original_max_tokens = request_data['max_tokens'] + request_id, zmq_msg_request.model_dump_json(), header) + original_max_tokens = request.max_tokens # change max_tokens = 1 to let it only do prefill - request_data['max_tokens'] = 1 + request.max_tokens = 1 # finish prefill try: - prefill_response = await prefill(header, request_data) - if isinstance(prefill_response, JSONResponse): + prefill_response = await prefill(zmq_msg_request) + if isinstance(prefill_response, JSONResponse + ) and prefill_response.status_code != HTTPStatus.OK: return prefill_response logger.debug("finish prefill start decode") - request_data['max_tokens'] = original_max_tokens - response = await decode(header, request_data) + request.max_tokens = original_max_tokens + response = await decode(zmq_msg_request) logger.debug("finish decode") except Exception as e: logger.error("Error occurred in disagg prefill proxy server, %s", e) - response = JSONResponse({"error": { - "message": str(e) - }}, - status_code=500) + response = JSONResponse( + {"error": { + "message": str(e) + }}, + status_code=HTTPStatus.INTERNAL_SERVER_ERROR) return response except Exception as e: @@ -120,37 +124,27 @@ async def completions(request: Request, background_tasks: BackgroundTasks): logger.error("".join(traceback.format_exception(*exc_info))) response = JSONResponse({"error": { "message": str(e) - }}, - status_code=500) + }}, HTTPStatus.INTERNAL_SERVER_ERROR) return response finally: - background_tasks.add_task(cleanup_request_id, request_id) + if request_id is not None: + background_tasks.add_task(cleanup_request_id, request_id) async def socket_recv_handler(socket: zmq.asyncio.Socket, scene: str): while True: try: - [request_id, contentType, reply] = await socket.recv_multipart() - contentType_str = contentType.decode() - reply_str = reply.decode() - request_id_str = request_id.decode() - logger.debug( - "%s socket received result contentType: %s, " - "request_id: %s, reply: %s", scene, contentType_str, - request_id_str, reply_str) - if request_id_str in request_queues: - request_queues[request_id_str].put_nowait( - (contentType_str, reply_str)) - if "[DONE]" in reply_str: - logger.debug( - "%s socket received stop signal request_id: %s", scene, - request_id_str) - request_queues[request_id_str].put_nowait( - (contentType_str, None)) + [body] = await socket.recv_multipart() + response = ZmqMsgResponse.model_validate_json(body) + request_id = response.request_id + logger.debug("%s socket received result: %s", scene, + response.model_dump_json()) + if request_id in request_queues: + request_queues[request_id].put_nowait(response) else: logger.debug( "%s socket received but request_id not found discard: %s", - scene, request_id_str) + scene, request_id) except Exception as e: logger.error(traceback.format_exc()) logger.error("%s handler error: %s", scene, e) @@ -167,78 +161,70 @@ async def decode_handler(decode_socket: zmq.asyncio.Socket): # select a socket and execute task -async def execute_task_async(headers: dict, request: dict, +async def execute_task_async(zmq_msg_request: ZmqMsgRequest, socket: zmq.asyncio.Socket): try: - request_id = headers.get(X_REQUEST_ID_KEY) - requestBody = json.dumps(request) - logger.info("Sending requestBody: %s", requestBody) - socket.send_multipart([request_id.encode(), requestBody.encode()]) + request_id = zmq_msg_request.request_id + requestBody = zmq_msg_request.model_dump_json() + logger.debug("Sending requestBody: %s", requestBody) + socket.send_multipart([requestBody.encode()]) logger.debug("Sent end") queue = request_queues[request_id] while True: logger.debug("Waiting for reply") - (contentType, - reply) = await asyncio.wait_for(queue.get(), TIME_OUT) - logger.debug("Received result: %s, %s", contentType, reply) - if reply is None: - logger.debug("Received stop signal, request_id: %s", - request_id) + zmq_msg_response: ZmqMsgResponse = await asyncio.wait_for( + queue.get(), TIME_OUT) + logger.debug("Received result: %s", + zmq_msg_response.model_dump_json()) + yield zmq_msg_response + if zmq_msg_response.stop: + logger.debug("Received stop: %s", zmq_msg_response.stop) break - yield (contentType, reply) - if contentType == CONTENT_TYPE_JSON: - logger.debug("Received %s message, request_id: %s", - contentType, request_id) - break - except asyncio.TimeoutError: logger.error(traceback.format_exc()) - yield (CONTENT_TYPE_ERROR, "System Error") + yield JSONResponse("timeout", HTTPStatus.REQUEST_TIMEOUT) finally: logger.debug("request_id: %s, execute_task_async end", request_id) -async def prefill(header: dict, - original_request_data: dict) -> Union[JSONResponse, bool]: +async def prefill(zmq_msg_request: ZmqMsgRequest) -> Union[JSONResponse, bool]: logger.debug("start prefill") - generator = execute_task_async(header, original_request_data, - app.state.prefill_socket) - async for contentType, reply in generator: - logger.debug("contentType: %s, reply: %s", contentType, reply) - if contentType == CONTENT_TYPE_ERROR: - response = JSONResponse({"error": reply}) - response.status_code = 500 - return response + generator = execute_task_async(zmq_msg_request, app.state.prefill_socket) + async for res in generator: + logger.debug("res: %s", res) + if res.body_type == "response": + return JSONResponse(res.body) return True -async def generate_stream_response(fisrt_reply: str, - generator: AsyncGenerator): +async def generate_stream_response( + fisrt_reply: str, generator: AsyncGenerator[ZmqMsgResponse] +) -> AsyncGenerator[dict, str]: yield fisrt_reply - async for _, reply in generator: - yield reply + async for reply in generator: + yield reply.body async def decode( - header: dict, - original_request_data: dict) -> Union[JSONResponse, StreamingResponse]: - logger.info("start decode") - generator = execute_task_async(header, original_request_data, - app.state.decode_socket) + zmq_msg_request: ZmqMsgRequest +) -> Union[JSONResponse, StreamingResponse]: + logger.debug("start decode") + generator = execute_task_async(zmq_msg_request, app.state.decode_socket) - async for contentType, reply in generator: - logger.debug("contentType: %s, reply: %s", contentType, reply) - if contentType == CONTENT_TYPE_ERROR: - response = JSONResponse({"error": reply}) - response.status_code = 500 - return response - elif contentType == CONTENT_TYPE_JSON: - return JSONResponse(reply) + async for res in generator: + logger.debug("res: %s", res) + if res.body_type == "response": + return JSONResponse(res.body) else: return StreamingResponse(generate_stream_response( - reply, generator), + res.body, generator), media_type=CONTENT_TYPE_STREAM) + # If the generator is empty, return a default error response + logger.error("No response received from generator") + return JSONResponse({"error": "No response received from generator"}, + status_code=HTTPStatus.INTERNAL_SERVER_ERROR) + def cleanup_request_id(request_id: str): if request_id in request_queues: diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index a96ca1f757008..a91cae3a34a7c 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1649,3 +1649,19 @@ class TranscriptionResponseVerbose(OpenAIBaseModel): words: Optional[list[TranscriptionWord]] = None """Extracted words and their corresponding timestamps.""" + + +class ZmqMsgRequest(BaseModel): + request_id: str + type: str + body: Union[CompletionRequest] + + +class ZmqMsgResponse(BaseModel): + request_id: str + type: str + stop: bool = True + body_type: Literal["str", "response"] = "str" + body: Union[dict, str] = None + + model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/vllm/entrypoints/openai/zmq_server.py b/vllm/entrypoints/openai/zmq_server.py index b5404a7d766ae..b1ccd20e831e3 100644 --- a/vllm/entrypoints/openai/zmq_server.py +++ b/vllm/entrypoints/openai/zmq_server.py @@ -1,22 +1,22 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio -import json import os import signal import traceback from argparse import Namespace +from http import HTTPStatus import zmq import zmq.asyncio -from fastapi import Request from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.openai.api_server import build_async_engine_client from vllm.entrypoints.openai.protocol import (CompletionRequest, CompletionResponse, - ErrorResponse) + ErrorResponse, ZmqMsgRequest, + ZmqMsgResponse) from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_models import (BaseModelPath, OpenAIServingModels) @@ -26,10 +26,6 @@ from vllm.version import __version__ as VLLM_VERSION logger = init_logger('vllm.entrypoints.openai.zmq_server') -CONTENT_TYPE_JSON = "application/json" -CONTENT_TYPE_ERROR = "error" -CONTENT_TYPE_STREAM = "text/event-stream" - openai_serving_completion: OpenAIServingCompletion openai_serving_models: OpenAIServingModels @@ -72,16 +68,22 @@ async def serve_zmq(arg) -> None: try: logger.debug("zmq Server waiting for request") # get new request from the client - identity, request_id, body = await socket.recv_multipart() + message_parts = await socket.recv_multipart() + logger.debug("received request: %s", message_parts) + logger.debug("received len: %d", len(message_parts)) + identity, body = message_parts[0], message_parts[1] + zmq_msg_request = ZmqMsgRequest.model_validate_json(body) # launch request handler coroutine task = asyncio.create_task( - worker_routine(identity, request_id, body, socket)) + worker_routine(identity, zmq_msg_request, socket)) running_requests.add(task) task.add_done_callback(running_requests.discard) except zmq.ZMQError as e: + logger.error(traceback.format_exc()) logger.error("ZMQError: %s", e) break except Exception as e: + logger.error(traceback.format_exc()) logger.error("Unexpected error: %s", e) break except KeyboardInterrupt: @@ -153,59 +155,68 @@ async def init_state( ) -async def worker_routine(identity: bytes, request_id: bytes, body: bytes, +async def worker_routine(identity: bytes, zmq_msg_request: ZmqMsgRequest, socket: zmq.asyncio.Socket): """Worker routine""" try: - body_json = json.loads(body.decode()) - request_id_str = request_id.decode() - logger.debug("receive request: %s from %s, request_id: %s", body_json, - identity.decode(), request_id_str) - - completionRequest = CompletionRequest(**body_json) - generator = await create_completion(completionRequest, None) - content_type_json = CONTENT_TYPE_JSON.encode('utf-8') - content_type_stream = CONTENT_TYPE_STREAM.encode('utf-8') - if isinstance(generator, ErrorResponse): - content = json.loads(generator.model_dump_json()) - content.update({"status_code": generator.code}) - logger.debug("send ErrorResponse %s", json.dumps(content)) - await socket.send_multipart([ - identity, request_id, content_type_json, - json.dumps(content).encode('utf-8') - ]) - elif isinstance(generator, CompletionResponse): - logger.debug("send CompletionResponse %s", - json.dumps(generator.model_dump())) - await socket.send_multipart([ - identity, request_id, content_type_json, - json.dumps(generator.model_dump()).encode('utf-8') - ]) + request_id = zmq_msg_request.request_id + logger.debug("receive request: %s from %s, request_id: %s", + zmq_msg_request.model_dump_json(), identity.decode(), + request_id) + if isinstance(zmq_msg_request.body, CompletionRequest): + await create_completion(identity, zmq_msg_request, socket) else: - async for chunk in generator: - logger.debug( - "send chunk identity: %s, request_id: %s, chunk: %s", - identity.decode(), request_id.decode(), chunk) - await socket.send_multipart([ - identity, request_id, content_type_stream, - chunk.encode('utf-8') - ]) + logger.error("Error in worker routine: %s request_id: %s", + "unsupported request type", request_id) + raise Exception("unsupported request type") + except Exception as e: logger.error("Error in worker routine: %s request_id: %s", e, - request_id_str) + request_id) logger.error(traceback.format_exc()) - content_type_stream = CONTENT_TYPE_STREAM.encode('utf-8') logger.debug("send ErrorResponse %s", str(e)) await socket.send_multipart([ - identity, request_id, - CONTENT_TYPE_ERROR.encode('utf-8'), - str(e).encode('utf-8') + identity, + ZmqMsgResponse(request_id=request_id, + type=zmq_msg_request.type, + body={ + "content": "unsupported request type", + "status_code": HTTPStatus.INTERNAL_SERVER_ERROR + }).model_dump_json().encode(), ]) -async def create_completion(request: CompletionRequest, raw_request: Request): - logger.debug("zmq request post: %s", request) - generator = await openai_serving_completion.create_completion( - request, raw_request) +async def create_completion(identity: bytes, zmq_msg_request: ZmqMsgRequest, + socket: zmq.asyncio.Socket): + request: CompletionRequest = zmq_msg_request.body + logger.debug("zmq request post: %s", request.model_dump_json()) + generator = await openai_serving_completion.create_completion(request) logger.debug("zmq request end post") - return generator + request_id = zmq_msg_request.request_id + if isinstance(generator, (ErrorResponse, CompletionResponse)): + logger.debug("send response %s", generator.model_dump_json()) + zmq_msg_response = ZmqMsgResponse(request_id=request_id, + type=zmq_msg_request.type, + body_type="response") + if isinstance(generator, ErrorResponse): + zmq_msg_response.body = { + "content": generator.model_dump(), + "status_code": generator.code + } + elif isinstance(generator, CompletionResponse): + zmq_msg_response.body = {"content": generator.model_dump()} + + await socket.send_multipart( + [identity, zmq_msg_response.model_dump_json().encode()]) + else: + async for chunk in generator: + zmq_msg_response = ZmqMsgResponse(request_id=request_id, + type=zmq_msg_request.type, + body=chunk) + if "data: [DONE]" not in chunk: + zmq_msg_response.stop = False + logger.debug("send chunk identity: %s, request_id: %s, chunk: %s", + identity.decode(), request_id, chunk) + await socket.send_multipart( + [identity, + zmq_msg_response.model_dump_json().encode()])