diff --git a/vllm/entrypoints/connect.py b/vllm/entrypoints/connect.py
index a7bb5a9daccf7..439bfa5088841 100644
--- a/vllm/entrypoints/connect.py
+++ b/vllm/entrypoints/connect.py
@@ -50,22 +50,22 @@ async def create_socket_pool(url: str, num_sockets: int, zmqctx: zmq.asyncio.Context):
     return sockets
 
 # select a socket and execute task
-async def execute_task_async(route: str, headers: Headers, request: dict, sockets: list):
+async def execute_task_async(route: str, headers: dict, request: dict, sockets: list):
     sock = await sockets.get()
     try:
         requestBody = json.dumps(request)
-        headersJson = json.dumps(dict(headers))
+        headersJson = json.dumps(headers)
         logger.info(f"Sending requestBody: {requestBody} to {route} with headers: {headersJson}")
         await sock.send_multipart([route.encode(), headersJson.encode(), requestBody.encode()])
         logger.info(f"Sent end")
         while True:
             logger.info(f"Waiting for reply")
-            reply = await sock.recv_multipart()
-            logger.info(f"Received result: {reply}")
-            yield f"data: {reply[0].decode()}\n\n"
-            if "finish_reason" in reply[0].decode() and "stop" in reply[0].decode():
+            contentType, reply = await sock.recv_multipart()
+            logger.info(f"Received result: {contentType}, {reply}")
+            reply = reply.decode()
+            yield reply
+            if "[DONE]" in reply:
                 logger.info(f"Received stop signal, return socket")
-                yield "data: [DONE]\n\n"
                 break
     finally:
         await sockets.put(sock)
@@ -73,16 +73,20 @@ async def execute_task_async(route: str, headers: Headers, request: dict, sockets: list):
 @app.post('/v1/connect/completions')
 async def chat_completions(request: Request):
     try:
+        # Add an X-Request-Id header if the client did not supply one
+        x_request_id = str(uuid.uuid4())
+        header = dict(request.headers)
+        if header.get("X-Request-Id") is None:
+            logger.info(f"add X-Request-Id: {x_request_id}")
+            header["X-Request-Id"] = x_request_id
         original_request_data = await request.json()
-        header = request.headers
-        logger.info(f"Received request: {original_request_data}")
+        logger.info(f"Received request: {original_request_data} header: {header}")
         prefill_request = original_request_data.copy()
         # change max_tokens = 1 to let it only do prefill
         prefill_request['max_tokens'] = 1
         route = "/v1/completions"
         # finish prefill
         async for x in execute_task_async(route, header, prefill_request, app.state.sockets_prefill):
-            logger.info(f"{x}")
             continue
 
         # return decode
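The pool consumed by execute_task_async is produced by create_socket_pool, whose body falls outside the hunk above; only its signature and the `return sockets` context line are visible. A minimal sketch of a compatible implementation, assuming the pool is an asyncio.Queue of connected DEALER sockets (the queue type and the identity scheme are assumptions inferred from the sockets.get()/sockets.put() calls, not part of the diff):

import asyncio
import uuid

import zmq
import zmq.asyncio


async def create_socket_pool(url: str, num_sockets: int,
                             zmqctx: zmq.asyncio.Context) -> asyncio.Queue:
    # Sketch only: pre-connect a fixed number of DEALER sockets and hand
    # them out through a queue so each in-flight request owns one socket.
    sockets: asyncio.Queue = asyncio.Queue()
    for i in range(num_sockets):
        sock = zmqctx.socket(zmq.DEALER)
        # A distinct identity per socket lets the server-side ROUTER route
        # reply frames back to the socket that sent the request.
        sock.setsockopt(zmq.IDENTITY, f"pool-{i}-{uuid.uuid4()}".encode())
        sock.connect(url)
        await sockets.put(sock)
    return sockets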
diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py
index b09ee526f14ae..dd098693cf36d 100644
--- a/vllm/entrypoints/launcher.py
+++ b/vllm/entrypoints/launcher.py
@@ -6,9 +6,13 @@ import socket
 from http import HTTPStatus
 from typing import Any, Optional
 
+import zmq
+import zmq.asyncio
+
 import uvicorn
 from fastapi import FastAPI, Request, Response
 
+from vllm.entrypoints.openai.connect_worker import worker_routine
 from vllm import envs
 from vllm.engine.async_llm_engine import AsyncEngineDeadError
 from vllm.engine.multiprocessing import MQEngineDeadError
@@ -72,8 +76,39 @@ async def serve_http(app: FastAPI,
                 "port %s is used by process %s launched with command:\n%s",
                 port, process, " ".join(process.cmdline()))
 
     logger.info("Shutting down FastAPI HTTP server.")
-    return server.shutdown()
+    return server.shutdown()
+
+
+async def serve_zmq(args, zmq_server_port: int, app: FastAPI) -> None:
+    """Run a ZMQ ROUTER/DEALER proxy that fans requests out to in-process workers."""
+    logger.info(f"ZMQ server starting, args: {args}, zmq_port: {zmq_server_port}")
+    url_worker = "inproc://workers"
+    url_client = f"tcp://0.0.0.0:{zmq_server_port}"
+    # Prepare our context and sockets
+    context = zmq.asyncio.Context()
+
+    # Socket to talk to clients
+    clients = context.socket(zmq.ROUTER)
+    clients.bind(url_client)
+    logger.info(f"ZMQ server ROUTER bound at {url_client}")
+    # Socket to talk to workers
+    workers = context.socket(zmq.DEALER)
+    workers.bind(url_worker)
+    logger.info(f"ZMQ worker DEALER bound at {url_worker}")
+
+    tasks = [
+        asyncio.create_task(worker_routine(url_worker, app, context, i))
+        for i in range(5)
+    ]
+    # zmq.proxy blocks, so run it in a thread alongside the worker tasks
+    proxy_task = asyncio.to_thread(zmq.proxy, clients, workers)
+
+    try:
+        await asyncio.gather(*tasks, proxy_task)
+    except KeyboardInterrupt:
+        logger.info("ZMQ server interrupted")
+    except zmq.ZMQError as e:
+        logger.error("ZMQError: %s", e)
+    finally:
+        # Normally unreachable; clean up anyhow
+        clients.close()
+        workers.close()
+        context.term()
 
 
 def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
     """Adds handlers for fatal errors that should crash the server"""
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index f9b1d69a31d8c..2c67cc1644eb9 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -36,7 +36,7 @@ from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.engine.multiprocessing.engine import run_mp_engine
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import load_chat_template
-from vllm.entrypoints.launcher import serve_http
+from vllm.entrypoints.launcher import serve_http, serve_zmq
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                               validate_parsed_serve_args)
@@ -1029,6 +1029,11 @@ async def run_server(args, **uvicorn_kwargs) -> None:
                 "s" if is_ssl else "",
                 _listen_addr(sock_addr[0]),
                 sock_addr[1])
 
+    zmq_server_port = args.zmq_server_port
+    if zmq_server_port is not None:
+        logger.info("Starting ZMQ server on port %d", zmq_server_port)
+        asyncio.create_task(serve_zmq(args, zmq_server_port, app))
+
     shutdown_task = await serve_http(
         app,
         sock=sock,
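Taken together, serve_zmq and the client code in connect.py define a small wire protocol: send [route, headers-json, body-json] as one multipart message on a DEALER socket, then read [content-type, payload] frames back until either a single application/json payload or a streaming chunk containing [DONE] arrives. A standalone client sketch speaking that protocol; the port and request payload are illustrative, not values taken from the diff:

import asyncio
import json

import zmq
import zmq.asyncio


async def main() -> None:
    ctx = zmq.asyncio.Context()
    sock = ctx.socket(zmq.DEALER)
    sock.connect("tcp://127.0.0.1:5570")  # assumed --zmq-server-port value
    request = {"model": "facebook/opt-125m", "prompt": "Hello",
               "max_tokens": 16, "stream": True}
    headers = {"X-Request-Id": "demo-1"}
    await sock.send_multipart([b"/v1/completions",
                               json.dumps(headers).encode(),
                               json.dumps(request).encode()])
    while True:
        content_type, payload = await sock.recv_multipart()
        print(content_type.decode(), payload.decode().strip())
        # JSON replies arrive as a single frame; SSE streams end with [DONE]
        if content_type == b"application/json" or b"[DONE]" in payload:
            break
    sock.close()
    ctx.term()


asyncio.run(main())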
diff --git a/vllm/entrypoints/openai/connect_worker.py b/vllm/entrypoints/openai/connect_worker.py
new file mode 100644
index 0000000000000..0a641d423a58f
--- /dev/null
+++ b/vllm/entrypoints/openai/connect_worker.py
@@ -0,0 +1,129 @@
+import json
+import traceback
+import uuid
+from typing import Optional
+
+import httpx
+import zmq
+import zmq.asyncio
+from fastapi import FastAPI, Request
+from fastapi.responses import JSONResponse
+
+# yapf conflicts with isort for this block
+# yapf: disable
+from vllm.entrypoints.openai.protocol import (CompletionRequest,
+                                              CompletionResponse,
+                                              ErrorResponse)
+from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
+# yapf: enable
+from vllm.logger import init_logger
+
+# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
+logger = init_logger('vllm.entrypoints.openai.connect_worker')
+
+
+def base(app: FastAPI) -> OpenAIServing:
+    # Reuse the existing instance
+    return tokenization(app)
+
+
+def models(app: FastAPI) -> OpenAIServingModels:
+    return app.state.openai_serving_models
+
+
+def chat(app: FastAPI) -> Optional[OpenAIServingChat]:
+    return app.state.openai_serving_chat
+
+
+def completion(app: FastAPI) -> Optional[OpenAIServingCompletion]:
+    return app.state.openai_serving_completion
+
+
+def tokenization(app: FastAPI) -> OpenAIServingTokenization:
+    return app.state.openai_serving_tokenization
+
+
+def bytes_to_headers(bytes_data: bytes) -> httpx.Headers:
+    headers_dict = json.loads(bytes_data.decode())
+    return httpx.Headers(headers_dict)
+
+
+async def worker_routine(worker_url: str, app: FastAPI,
+                         context: Optional[zmq.asyncio.Context] = None,
+                         i: int = 0):
+    """Receive requests from the dispatcher, run them through the OpenAI
+    serving layer, and stream the results back to the client."""
+    try:
+        context = context or zmq.asyncio.Context.instance()
+        # Socket to talk to the dispatcher
+        socket = context.socket(zmq.DEALER)
+        worker_identity = f"worker-{i}-{uuid.uuid4()}"
+        socket.setsockopt(zmq.IDENTITY, worker_identity.encode())
+        socket.connect(worker_url)
+        logger.info(f"{worker_identity} started at {worker_url}")
+        while True:
+            identity, url, header, body = await socket.recv_multipart()
+            url = url.decode()
+            header = bytes_to_headers(header)
+            body = json.loads(body.decode())
+            logger.info(f"worker-{i} received request [{identity}] "
+                        f"url: [{url}] headers: [{header}] body: [{body}]")
+            completion_request = CompletionRequest(**body)
+            raw_request = create_request(url, "POST", body, header)
+            generator = await create_completion(app, completion_request,
+                                                raw_request)
+            logger.info(f"worker-{i} received response: [{generator}]")
+            if isinstance(generator, ErrorResponse):
+                error_content = json.loads(generator.model_dump_json())
+                error_content["status_code"] = generator.code
+                await socket.send_multipart([
+                    identity, b"application/json",
+                    json.dumps(error_content).encode()
+                ])
+            elif isinstance(generator, CompletionResponse):
+                await socket.send_multipart([
+                    identity, b"application/json",
+                    JSONResponse(content=generator.model_dump()).body
+                ])
+            else:
+                async for chunk in generator:
+                    logger.info(f"worker-{i} sending response chunk: [{chunk}]")
+                    await socket.send_multipart(
+                        [identity, b"text/event-stream", chunk.encode()])
+    except Exception as e:
+        logger.error(f"Error in worker-{i} routine: {e}")
+        logger.error(traceback.format_exc())
+
+
+async def create_completion(app: FastAPI, request: CompletionRequest,
+                            raw_request: Request):
+    handler = completion(app)
+    logger.info(f"zmq request post: {request}")
+    if handler is None:
+        return base(app).create_error_response(
+            message="The model does not support Completions API")
+
+    generator = await handler.create_completion(request, raw_request)
+    logger.info(f"zmq request end post: {generator}")
+    return generator
+
+
+def create_request(path: str, method: str, body: dict,
+                   headers: Optional[httpx.Headers] = None) -> Request:
+    # ASGI requires header pairs to be bytes, not str
+    scope = {
+        'type': 'http',
+        'http_version': '1.1',
+        'method': method,
+        'path': path,
+        'headers': [(k.encode('latin-1'), v.encode('latin-1'))
+                    for k, v in headers.items()] if headers else [],
+    }
+    raw_body = json.dumps(body).encode('utf-8') if body else b''
+
+    async def receive():
+        return {
+            'type': 'http.request',
+            'body': raw_body,
+        }
+
+    async def send(message):
+        pass
+
+    return Request(scope, receive=receive, send=send)
+
+
+if __name__ == "__main__":
+    print(bytes_to_headers(b'{"Content-Type": "application/json"}'))
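create_request builds a Starlette Request from a hand-rolled ASGI scope, which is what lets worker_routine call the OpenAI serving layer without a live HTTP connection. A self-contained demonstration of the same technique; make_request and the values passed to it are illustrative, not part of the module above:

import asyncio
import json

from fastapi import Request


def make_request(path: str, body: dict, headers: dict) -> Request:
    payload = json.dumps(body).encode()
    scope = {
        "type": "http",
        "http_version": "1.1",
        "method": "POST",
        "path": path,
        # ASGI header names/values must be bytes; names are lower-cased
        "headers": [(k.lower().encode("latin-1"), v.encode("latin-1"))
                    for k, v in headers.items()],
    }

    async def receive():
        # One-shot body; more_body=False tells Starlette the stream is done
        return {"type": "http.request", "body": payload, "more_body": False}

    return Request(scope, receive=receive)


async def demo() -> None:
    req = make_request("/v1/completions", {"prompt": "hi"},
                       {"Content-Type": "application/json"})
    print(req.method, req.url.path, req.headers["content-type"])
    print(await req.json())


asyncio.run(demo())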
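One loose end: run_server reads args.zmq_server_port, but none of the hunks shown here registers that flag, so it presumably lands in cli_args.py alongside this change. A plausible sketch of that addition; the flag name, default, and help text are assumptions inferred from the attribute name:

import argparse


def add_zmq_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    # Hypothetical flag matching args.zmq_server_port in run_server;
    # None (the assumed default) leaves the ZMQ server disabled.
    parser.add_argument("--zmq-server-port",
                        type=int,
                        default=None,
                        help="If set, additionally serve completion requests "
                             "over a ZMQ ROUTER socket bound to this port.")
    return parser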