diff --git a/vllm/entrypoints/disagg_connector.py b/vllm/entrypoints/disagg_connector.py index 342f5ce3b6549..2b2749680f625 100644 --- a/vllm/entrypoints/disagg_connector.py +++ b/vllm/entrypoints/disagg_connector.py @@ -4,7 +4,6 @@ import uvicorn import zmq import zmq.asyncio from fastapi import FastAPI, Request -from starlette.datastructures import Headers from fastapi.responses import StreamingResponse from contextlib import asynccontextmanager # from fastapi.lifespan import Lifespan @@ -24,7 +23,7 @@ logger = init_logger('vllm.entrypoints.connect') @asynccontextmanager async def lifespan(app: FastAPI): - # create scoket pool with prefill and decode + # create socket pool with prefill and decode logger.info("start create_socket_pool") app.state.zmqctx = zmq.asyncio.Context() app.state.sockets_prefill = await create_socket_pool(app.state.prefill_addr, socket_prefill_num, zmqctx=app.state.zmqctx) @@ -39,7 +38,7 @@ async def lifespan(app: FastAPI): app = FastAPI(lifespan=lifespan) # create async socket pool with num_sockets use ZMQ_DEALER -async def create_socket_pool(url: str, num_sockets: int, zmqctx: zmq.asyncio.Context): +async def create_socket_pool(url: str, num_sockets: int, zmqctx: zmq.asyncio.Context) -> Queue: sockets = Queue() for i in range(num_sockets): sock = zmqctx.socket(zmq.DEALER) @@ -50,8 +49,8 @@ async def create_socket_pool(url: str, num_sockets: int, zmqctx: zmq.asyncio.Con await sockets.put(sock) return sockets -# select a scoket and execute task -async def execute_task_async(route: str, headers: dict, request: dict, sockets: list): +# select a socket and execute task +async def execute_task_async(route: str, headers: dict, request: dict, sockets: Queue): sock = await sockets.get() try: requestBody = json.dumps(request) diff --git a/vllm/entrypoints/openai/connect_worker.py b/vllm/entrypoints/openai/connect_worker.py index 0a641d423a58f..0a6721153101d 100644 --- a/vllm/entrypoints/openai/connect_worker.py +++ b/vllm/entrypoints/openai/connect_worker.py @@ -1,12 +1,13 @@ -import json -from typing import Optional + import zmq import zmq.asyncio import tempfile import uuid import httpx import json +import traceback +from typing import Optional from fastapi import FastAPI, Request from fastapi.responses import JSONResponse @@ -22,7 +23,6 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization from vllm.logger import init_logger -import traceback prometheus_multiproc_dir: tempfile.TemporaryDirectory @@ -54,7 +54,7 @@ def bytes_to_headers(bytes_data: bytes) -> httpx.Headers: return httpx.Headers(headers_dict) async def worker_routine(worker_url: str, app: FastAPI, - context: zmq.asyncio.Context = None, i: int = 0): + context: zmq.asyncio.Context, i: int = 0): """Worker routine""" try: # Socket to talk to dispatcher @@ -65,46 +65,46 @@ async def worker_routine(worker_url: str, app: FastAPI, logger.info(f"{worker_identity} started at {worker_url}") while True: identity, url, header, body = await socket.recv_multipart() - logger.info(f"worker-{i} Received request identity: [{identity} ]") - url = url.decode() - logger.info(f"worker-{i} Received request url: [{url} ]") - header = bytes_to_headers(header) - logger.info(f"worker-{i} Received request headers: [{header} ]") - body = json.loads(body.decode()) - logger.info(f"worker-{i} Received request body: [{body} ]") + logger.info(f"worker-{i} Received request identity: [{identity.decode()} ]") + url_str = url.decode() + logger.info(f"worker-{i} Received request url: [{url_str} ]") + headers = bytes_to_headers(header) + logger.info(f"worker-{i} Received request headers: [{headers} ]") + body_json = json.loads(body.decode()) + logger.info(f"worker-{i} Received request body: [{body_json} ]") logger.info(f"worker-{i} Calling OpenAI API") - completionRequest = CompletionRequest(**body) - createRequest = create_request(url, "POST", body, header) + completionRequest = CompletionRequest(**body_json) + createRequest = create_request(url_str, "POST", body_json, headers) generator = await create_completion(app, completionRequest, createRequest) logger.info(f"worker-{i} Received response: [{generator} ]") if isinstance(generator, ErrorResponse): content = generator.model_dump_json() - context = json.loads(content) - context.append("status_code", generator.code) - await socket.send_multipart([identity, b"application/json", json.dumps(context).encode()]) + context_json = json.loads(content) + context_json.append("status_code", generator.code) + await socket.send_multipart([identity, b"application/json", json.dumps(context_json).encode('utf-8')]) elif isinstance(generator, CompletionResponse): - await socket.send_multipart([identity, b"application/json", JSONResponse.render(content=generator.model_dump())]) + await socket.send_multipart([identity, b"application/json", json.dumps(generator.model_dump()).encode('utf-8')]) else: async for chunk in generator: logger.info(f"worker-{i} Sending response chunk: [{chunk} ]") - await socket.send_multipart([identity, b"text/event-stream", chunk.encode()]) + await socket.send_multipart([identity, b"text/event-stream", chunk.encode('utf-8')]) except Exception as e: logger.error(f"Error in worker routine: {e} worker-{i}") logger.error(traceback.format_exc()) async def create_completion(app: FastAPI, request: CompletionRequest, raw_request: Request): handler = completion(app) - logger.info(f"zmq requset post: {request}") + logger.info(f"zmq request post: {request}") if handler is None: return base(app).create_error_response( message="The model does not support Completions API") generator = await handler.create_completion(request, raw_request) - logger.info(f"zmq requset end post: {generator}") + logger.info(f"zmq request end post: {generator}") return generator -def create_request(path: str, method: str, body: bytes, headers: dict = None): +def create_request(path: str, method: str, body: dict, headers: httpx.Headers) -> Request: scope = { 'type': 'http', 'http_version': '1.1', @@ -113,7 +113,7 @@ def create_request(path: str, method: str, body: bytes, headers: dict = None): 'headers': list(headers.items()) if headers else [], } if body: - scope['body'] = json.dumps(body).encode('utf-8') + scope['body'] = json.dumps(body) async def receive(): return { 'type': 'http.request',