From 2c31e4c3ea5962c1db33623adea0d5f6b923e836 Mon Sep 17 00:00:00 2001 From: clark Date: Thu, 9 Jan 2025 00:12:58 +0800 Subject: [PATCH] Run yapf and ruff Signed-off-by: clark --- .../disagg_benchmarks/zmq/test_request.py | 49 ++++++----- vllm/entrypoints/disagg_connector.py | 85 ++++++++++++------- vllm/entrypoints/launcher.py | 25 +++--- vllm/entrypoints/openai/api_server.py | 3 +- vllm/entrypoints/openai/connect_worker.py | 69 ++++++++------- 5 files changed, 140 insertions(+), 91 deletions(-) diff --git a/benchmarks/disagg_benchmarks/zmq/test_request.py b/benchmarks/disagg_benchmarks/zmq/test_request.py index 57cb1e548abe2..87c3b34d0181e 100644 --- a/benchmarks/disagg_benchmarks/zmq/test_request.py +++ b/benchmarks/disagg_benchmarks/zmq/test_request.py @@ -1,31 +1,34 @@ import asyncio import json + import aiohttp -# test connect completions we assume prefill and decode are on the same node -# 1. node:vllm serve facebook/opt-125m --port 7001 --zmq-server-port 7010 --chat-template ~/vllm/examples/template_chatglm2.jinja + +# test connect completions we assume prefill and decode are on the same node +# 1. node:vllm serve facebook/opt-125m --port 7001 --zmq-server-port 7010 \ +# --chat-template ~/vllm/examples/template_chatglm2.jinja # 2. vllm connect --prefill-addr nodeIp:7010 --decode-addr nodeIp:7010 # 3. python test_request.py - async def test_connect_completions(session): try: base_url = "http://localhost:8001/v1/connect/completions" body = { - "temperature": 0.5, - "top_p": 0.9, - "max_tokens": 150, - "frequency_penalty": 1.3, - "presence_penalty": 0.2, - "repetition_penalty": 1.2, - "model": "facebook/opt-125m", - "prompt": "Can you introduce vllm?", - "stream": True, - "stream_options": { + "temperature": 0.5, + "top_p": 0.9, + "max_tokens": 150, + "frequency_penalty": 1.3, + "presence_penalty": 0.2, + "repetition_penalty": 1.2, + "model": "facebook/opt-125m", + "prompt": "Can you introduce vllm?", + "stream": True, + "stream_options": { "include_usage": True - }} - print(f"Sending request to {base_url}, body {body}") - async with session.post(base_url, json= body) as response: - + } + } + print(f"Sending request to {base_url}, body {body}") + async with session.post(base_url, json=body) as response: + print(response.status) print(response.headers) responseText = "" @@ -40,13 +43,18 @@ async def test_connect_completions(session): print(f"Error decoding chunk: {chunk!r}") else: # Print the headers and JSON response - print(f"Unexpected Transfer-Encoding: {transfer_encoding} {response.headers} {await response.json()}") + print("Unexpected Transfer-Encoding: {} {} {}".format( + transfer_encoding, response.headers, await + response.json())) else: print(f"Request failed with status code {response.status}") - print(f"baseurl {base_url} response data {extract_data(responseText)}") + print( + f"baseurl {base_url} response data {extract_data(responseText)}" + ) except aiohttp.ClientError as e: print(f"Error: {e}") + def extract_data(responseText): reply = "" for data in responseText.split("\n\n"): @@ -66,7 +74,7 @@ def extract_data(responseText): return reply - + async def main(): async with aiohttp.ClientSession() as session: tasks = [] @@ -76,4 +84,3 @@ async def main(): asyncio.run(main()) - diff --git a/vllm/entrypoints/disagg_connector.py b/vllm/entrypoints/disagg_connector.py index 2b2749680f625..48c8b2b2abc66 100644 --- a/vllm/entrypoints/disagg_connector.py +++ b/vllm/entrypoints/disagg_connector.py @@ -1,15 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 + import json +import signal +import uuid +# from fastapi.lifespan import Lifespan +from asyncio import Queue +from contextlib import asynccontextmanager + import uvicorn import zmq import zmq.asyncio from fastapi import FastAPI, Request from fastapi.responses import StreamingResponse -from contextlib import asynccontextmanager -# from fastapi.lifespan import Lifespan -from asyncio import Queue -import uuid -import signal + from vllm.logger import init_logger # default prefill and decode url @@ -21,55 +24,69 @@ socket_decode_num = 5 # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) logger = init_logger('vllm.entrypoints.connect') + @asynccontextmanager async def lifespan(app: FastAPI): # create socket pool with prefill and decode logger.info("start create_socket_pool") app.state.zmqctx = zmq.asyncio.Context() - app.state.sockets_prefill = await create_socket_pool(app.state.prefill_addr, socket_prefill_num, zmqctx=app.state.zmqctx) + app.state.sockets_prefill = await create_socket_pool( + app.state.prefill_addr, socket_prefill_num, zmqctx=app.state.zmqctx) logger.info("success create_socket_pool sockets_prefill") - app.state.sockets_decode = await create_socket_pool(app.state.decode_addr, socket_decode_num, zmqctx=app.state.zmqctx) + app.state.sockets_decode = await create_socket_pool( + app.state.decode_addr, socket_decode_num, zmqctx=app.state.zmqctx) logger.info("success create_socket_pool sockets_decode") yield ## close zmq context logger.info("term zmqctx") app.state.zmqctx.destroy(linger=0) + app = FastAPI(lifespan=lifespan) + # create async socket pool with num_sockets use ZMQ_DEALER -async def create_socket_pool(url: str, num_sockets: int, zmqctx: zmq.asyncio.Context) -> Queue: - sockets = Queue() +async def create_socket_pool(url: str, num_sockets: int, + zmqctx: zmq.asyncio.Context) -> Queue: + sockets: Queue = Queue() for i in range(num_sockets): sock = zmqctx.socket(zmq.DEALER) identity = f"worker-{i}-{uuid.uuid4()}" sock.setsockopt(zmq.IDENTITY, identity.encode()) sock.connect(url) - logger.info(f"{identity} started at {url} {sockets.qsize()}") + logger.info("%s started at %s with queue size %s", identity, url, + sockets.qsize()) await sockets.put(sock) return sockets + # select a socket and execute task -async def execute_task_async(route: str, headers: dict, request: dict, sockets: Queue): +async def execute_task_async(route: str, headers: dict, request: dict, + sockets: Queue): sock = await sockets.get() try: requestBody = json.dumps(request) headersJson = json.dumps(headers) - logger.info(f"Sending requestBody: {requestBody} to {route} with headers: {headersJson}") - await sock.send_multipart([route.encode(), headersJson.encode(), requestBody.encode()]) - logger.info(f"Sent end") + logger.info("Sending requestBody: %s to %s with headers: %s", + requestBody, route, headersJson) + await sock.send_multipart( + [route.encode(), + headersJson.encode(), + requestBody.encode()]) + logger.info("Sent end") while True: - logger.info(f"Waiting for reply") + logger.info("Waiting for reply") [contentType, reply] = await sock.recv_multipart() - logger.info(f"Received result: {contentType}, {reply}") + logger.info("Received result: %s, %s", contentType, reply) reply = reply.decode() yield f"{reply}" if "[DONE]" in reply: - logger.info(f"Received stop signal, return socket") + logger.info("Received stop signal, return socket") break finally: await sockets.put(sock) + @app.post('/v1/connect/completions') async def chat_completions(request: Request): try: @@ -77,21 +94,26 @@ async def chat_completions(request: Request): x_request_id = str(uuid.uuid4()) header = dict(request.headers) if header.get("X-Request-Id") is None: - logger.info(f"add X-Request-Id: {x_request_id}") + logger.info("add X-Request-Id: %s", x_request_id) header["X-Request-Id"] = x_request_id original_request_data = await request.json() - logger.info(f"Received request: {original_request_data} header: {header}") + logger.info("Received request: %s header: %s", original_request_data, + header) prefill_request = original_request_data.copy() # change max_tokens = 1 to let it only do prefill prefill_request['max_tokens'] = 1 route = "/v1/completions" # finish prefill - async for _ in execute_task_async(route, header, prefill_request, app.state.sockets_prefill): + async for _ in execute_task_async(route, header, prefill_request, + app.state.sockets_prefill): continue # return decode - return StreamingResponse(execute_task_async(route, header,original_request_data, app.state.sockets_decode), media_type="text/event-stream") - + return StreamingResponse(execute_task_async(route, header, + original_request_data, + app.state.sockets_decode), + media_type="text/event-stream") + except Exception as e: import sys import traceback @@ -99,16 +121,20 @@ async def chat_completions(request: Request): logger.error("Error occurred in disagg prefill proxy server") logger.error(e) logger.error("".join(traceback.format_exception(*exc_info))) - + async def run_disagg_connector(args, **uvicorn_kwargs) -> None: - logger.info(f"vLLM Disaggregate Connector start {args} {uvicorn_kwargs}") + logger.info("vLLM Disaggregate Connector start %s %s", args, + uvicorn_kwargs) logger.info(args.prefill_addr) - app.state.prefill_addr = f"tcp://{args.prefill_addr}" if args.prefill_addr is not None else url_prefill - app.state.decode_addr = f"tcp://{args.decode_addr}" if args.decode_addr is not None else url_decode - logger.info(f"start connect url_prefill: {app.state.prefill_addr} url_decode: {app.state.decode_addr}") - + app.state.prefill_addr = (f"tcp://{args.prefill_addr}" if args.prefill_addr + is not None else url_prefill) + app.state.decode_addr = (f"tcp://{args.decode_addr}" + if args.decode_addr is not None else url_decode) + logger.info("start connect url_prefill: %s url_decode: %s", + app.state.prefill_addr, app.state.decode_addr) + def signal_handler(*_) -> None: # Interrupt server on sigterm while initializing raise KeyboardInterrupt("terminated") @@ -119,8 +145,7 @@ async def run_disagg_connector(args, **uvicorn_kwargs) -> None: server = uvicorn.Server(config) await server.serve() - if __name__ == "__main__": # url = 'tcp://127.0.0.1:5555' - uvicorn.run(app, host="0.0.0.0", port=8001) \ No newline at end of file + uvicorn.run(app, host="0.0.0.0", port=8001) diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index 3872346853da5..58a0b749e28dd 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -6,17 +6,16 @@ import socket from http import HTTPStatus from typing import Any, Optional +import uvicorn import zmq import zmq.asyncio - -import uvicorn from fastapi import FastAPI, Request, Response -from vllm.entrypoints.openai.connect_worker import worker_routine from vllm import envs from vllm.engine.async_llm_engine import AsyncEngineDeadError from vllm.engine.multiprocessing import MQEngineDeadError from vllm.entrypoints.ssl import SSLCertRefresher +from vllm.entrypoints.openai.connect_worker import worker_routine from vllm.logger import init_logger from vllm.utils import find_process_using_port @@ -76,11 +75,13 @@ async def serve_http(app: FastAPI, "port %s is used by process %s launched with command:\n%s", port, process, " ".join(process.cmdline())) logger.info("Shutting down FastAPI HTTP server.") - return server.shutdown() + return server.shutdown() + async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None: """Server routine""" - logger.info(f"zmq Server start arg: {arg}, zmq_port: {zmq_server_port}") + logger.info("zmq Server start arg: %s, zmq_server_port: %d", arg, + zmq_server_port) url_worker = "inproc://workers" url_client = f"tcp://0.0.0.0:{zmq_server_port}" # Prepare our context and sockets @@ -89,15 +90,18 @@ async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None: # Socket to talk to clients clients = context.socket(zmq.ROUTER) clients.bind(url_client) - logger.info(f"ZMQ Server ROUTER started at {url_client}") + logger.info("ZMQ Server ROUTER started at %s", url_client) # Socket to talk to workers workers = context.socket(zmq.DEALER) workers.bind(url_worker) - logger.info(f"ZMQ Worker DEALER started at {url_worker}") + logger.info("ZMQ Worker DEALER started at %s", url_worker) + + tasks = [ + asyncio.create_task(worker_routine(url_worker, app, context, i)) + for i in range(5) + ] + proxy_task = asyncio.to_thread(zmq.proxy, clients, workers) - tasks = [asyncio.create_task(worker_routine(url_worker, app, context, i)) for i in range(5)] - proxy_task = asyncio.to_thread(zmq.proxy, clients, workers) - try: await asyncio.gather(*tasks, proxy_task) except KeyboardInterrupt: @@ -110,6 +114,7 @@ async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None: workers.close() context.destroy(linger=0) + def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None: """Adds handlers for fatal errors that should crash the server""" diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 2c67cc1644eb9..86004feb7fee5 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1031,7 +1031,8 @@ async def run_server(args, **uvicorn_kwargs) -> None: zmq_server_port = args.zmq_server_port if zmq_server_port is not None: - logger.info("asyncio.create_task Starting ZMQ server at port %d", zmq_server_port) + logger.info("asyncio.create_task Starting ZMQ server at port %d", + zmq_server_port) asyncio.create_task(serve_zmq(args, zmq_server_port, app)) shutdown_task = await serve_http( diff --git a/vllm/entrypoints/openai/connect_worker.py b/vllm/entrypoints/openai/connect_worker.py index 0a6721153101d..b5665bd9e742d 100644 --- a/vllm/entrypoints/openai/connect_worker.py +++ b/vllm/entrypoints/openai/connect_worker.py @@ -1,27 +1,27 @@ +# SPDX-License-Identifier: Apache-2.0 +import json +import tempfile +import traceback +import uuid +from typing import Optional + +import httpx import zmq import zmq.asyncio -import tempfile -import uuid -import httpx -import json -import traceback - -from typing import Optional from fastapi import FastAPI, Request -from fastapi.responses import JSONResponse # yapf conflicts with isort for this block # yapf: disable from vllm.entrypoints.openai.protocol import (CompletionRequest, - CompletionRequest, CompletionResponse, ErrorResponse) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels -from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization +from vllm.entrypoints.openai.serving_tokenization import ( + OpenAIServingTokenization) from vllm.logger import init_logger prometheus_multiproc_dir: tempfile.TemporaryDirectory @@ -62,49 +62,61 @@ async def worker_routine(worker_url: str, app: FastAPI, worker_identity = f"worker-{i}-{uuid.uuid4()}" socket.setsockopt(zmq.IDENTITY, worker_identity.encode()) socket.connect(worker_url) - logger.info(f"{worker_identity} started at {worker_url}") + logger.info("%s started at %s", worker_identity, worker_url) while True: identity, url, header, body = await socket.recv_multipart() - logger.info(f"worker-{i} Received request identity: [{identity.decode()} ]") + logger.info("worker-%d Received request identity: [ %s ]", + i, identity.decode()) url_str = url.decode() - logger.info(f"worker-{i} Received request url: [{url_str} ]") + logger.info("worker-%d Received request url: [ %s ]", + i, url_str) headers = bytes_to_headers(header) - logger.info(f"worker-{i} Received request headers: [{headers} ]") + logger.info("worker-%d Received request headers: [ %s ]", + i, headers) body_json = json.loads(body.decode()) - logger.info(f"worker-{i} Received request body: [{body_json} ]") - logger.info(f"worker-{i} Calling OpenAI API") + logger.info("worker-%d Received request body: [ %s ]", + i, body_json) + logger.info("worker-%d Calling OpenAI API", i) completionRequest = CompletionRequest(**body_json) createRequest = create_request(url_str, "POST", body_json, headers) - generator = await create_completion(app, completionRequest, createRequest) - logger.info(f"worker-{i} Received response: [{generator} ]") + generator = await create_completion(app, completionRequest, + createRequest) + logger.info("worker-%d Received response: [ %s ]", i, generator) if isinstance(generator, ErrorResponse): content = generator.model_dump_json() context_json = json.loads(content) context_json.append("status_code", generator.code) - await socket.send_multipart([identity, b"application/json", json.dumps(context_json).encode('utf-8')]) + await socket.send_multipart([identity, b"application/json", + json.dumps(context_json).encode('utf-8')]) elif isinstance(generator, CompletionResponse): - await socket.send_multipart([identity, b"application/json", json.dumps(generator.model_dump()).encode('utf-8')]) + await socket.send_multipart([identity, b"application/json", + json.dumps(generator.model_dump()).encode('utf-8')]) else: async for chunk in generator: - logger.info(f"worker-{i} Sending response chunk: [{chunk} ]") - await socket.send_multipart([identity, b"text/event-stream", chunk.encode('utf-8')]) + logger.info("worker-%d Sending response chunk: [ %s ]", + i, chunk) + await socket.send_multipart([identity, + b"text/event-stream", + chunk.encode('utf-8')]) except Exception as e: - logger.error(f"Error in worker routine: {e} worker-{i}") + logger.error("Error in worker routine: %s worker-%d", e, i) logger.error(traceback.format_exc()) -async def create_completion(app: FastAPI, request: CompletionRequest, raw_request: Request): +async def create_completion(app: FastAPI, request: CompletionRequest, + raw_request: Request): handler = completion(app) - logger.info(f"zmq request post: {request}") + logger.info("zmq request post: %s", request) if handler is None: return base(app).create_error_response( message="The model does not support Completions API") generator = await handler.create_completion(request, raw_request) - logger.info(f"zmq request end post: {generator}") + logger.info("zmq request end post: %s", generator) return generator -def create_request(path: str, method: str, body: dict, headers: httpx.Headers) -> Request: +def create_request(path: str, method: str, body: dict, + headers: httpx.Headers) -> Request: scope = { 'type': 'http', 'http_version': '1.1', @@ -120,10 +132,9 @@ def create_request(path: str, method: str, body: dict, headers: httpx.Headers) - 'body': scope.get('body', b''), } async def send(message): - pass + pass return Request(scope, receive=receive, send=send) if __name__ == "__main__": print(bytes_to_headers(b'{"Content-Type": "application/json"}')) -