From 912031ceb58d35dee3a1c0831d36acb02589a927 Mon Sep 17 00:00:00 2001 From: clark Date: Sat, 8 Mar 2025 20:51:39 +0800 Subject: [PATCH] refactor disagg Signed-off-by: clark --- .../online_serving/disaggregated_prefill.sh | 25 +- .../disaggregated_prefill_zmq.sh | 113 +++++++ vllm/entrypoints/cli/connect.py | 72 ++++ vllm/entrypoints/cli/disagg.py | 115 +++++++ vllm/entrypoints/cli/main.py | 4 + vllm/entrypoints/disagg_connector.py | 309 +++++++++++------- vllm/entrypoints/launcher.py | 41 --- vllm/entrypoints/openai/api_server.py | 2 +- vllm/entrypoints/openai/connect_worker.py | 139 -------- vllm/entrypoints/openai/zmq_server.py | 211 ++++++++++++ 10 files changed, 713 insertions(+), 318 deletions(-) create mode 100644 examples/online_serving/disaggregated_prefill_zmq.sh create mode 100644 vllm/entrypoints/cli/connect.py create mode 100644 vllm/entrypoints/cli/disagg.py delete mode 100644 vllm/entrypoints/openai/connect_worker.py create mode 100644 vllm/entrypoints/openai/zmq_server.py diff --git a/examples/online_serving/disaggregated_prefill.sh b/examples/online_serving/disaggregated_prefill.sh index ef42854f362d1..bd0e2d44dbfb2 100644 --- a/examples/online_serving/disaggregated_prefill.sh +++ b/examples/online_serving/disaggregated_prefill.sh @@ -26,6 +26,14 @@ cleanup() { export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') +# install quart first -- required for disagg prefill proxy serve +if python3 -c "import quart" &> /dev/null; then + echo "Quart is already installed." +else + echo "Quart is not installed. Installing..." + python3 -m pip install quart +fi + # a function that waits vLLM server to start wait_for_server() { local port=$1 @@ -41,7 +49,6 @@ wait_for_server() { # prefilling instance, which is the KV producer CUDA_VISIBLE_DEVICES=0 vllm serve $MODEL_NAME \ --port 8100 \ - --zmq-server-port 7010 \ --max-model-len 100 \ --gpu-memory-utilization 0.8 \ --trust-remote-code \ @@ -51,25 +58,13 @@ CUDA_VISIBLE_DEVICES=0 vllm serve $MODEL_NAME \ # decoding instance, which is the KV consumer CUDA_VISIBLE_DEVICES=1 vllm serve $MODEL_NAME \ --port 8200 \ - --zmq-server-port 7011 \ --max-model-len 100 \ --gpu-memory-utilization 0.8 \ --trust-remote-code \ --kv-transfer-config \ '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' & -# launch a proxy server that opens the service at port 8000 -# the workflow of this proxy: -# - send the request to prefill vLLM instance (via zmq port 7010), change max_tokens -# to 1 -# - after the prefill vLLM finishes prefill, send the request to decode vLLM -# instance (via zmq port 7011) -vllm connect --port 8000 \ - --prefill-addr 127.0.0.1:7010 \ - --decode-addr 127.0.0.1:7011 & - -# wait until prefill, decode instances and proxy are ready -wait_for_server 8000 +# wait until prefill and decode instances are ready wait_for_server 8100 wait_for_server 8200 @@ -118,4 +113,4 @@ echo "Output of first request: $output1" echo "Output of second request: $output2" echo "🎉🎉 Successfully finished 2 test requests! 
🎉🎉" -echo "" +echo "" \ No newline at end of file diff --git a/examples/online_serving/disaggregated_prefill_zmq.sh b/examples/online_serving/disaggregated_prefill_zmq.sh new file mode 100644 index 0000000000000..81301e8b53386 --- /dev/null +++ b/examples/online_serving/disaggregated_prefill_zmq.sh @@ -0,0 +1,113 @@ +#!/bin/bash +# This file demonstrates the example usage of disaggregated prefilling with ZMQ +# We will launch 2 vllm instances (1 for prefill and 1 for decode), +# and then transfer the KV cache between them. + +set -xe + +echo "🚧🚧 Warning: The usage of disaggregated prefill is experimental and subject to change 🚧🚧" +sleep 1 + +# Trap the SIGINT signal (triggered by Ctrl+C) +trap 'cleanup' INT + +# Cleanup function +cleanup() { + echo "Caught Ctrl+C, cleaning up..." + # Cleanup commands + pgrep python | xargs kill -9 + pkill -f python + echo "Cleanup complete. Exiting." + exit 0 +} + +export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') + +# a function that waits vLLM connect to start +wait_for_server() { + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + + +# a function that waits vLLM disagg to start +wait_for_disagg_server() { + local log_file=$1 + timeout 1200 bash -c " + until grep -q 'zmq Server started at' $log_file; do + sleep 1 + done" && return 0 || return 1 +} + + +# You can also adjust --kv-ip and --kv-port for distributed inference. + +# prefilling instance, which is the KV producer +CUDA_VISIBLE_DEVICES=0 vllm disagg meta-llama/Meta-Llama-3.1-8B-Instruct \ + --zmq-server-addr testipc0 \ + --max-model-len 100 \ + --gpu-memory-utilization 0.8 \ + --kv-transfer-config \ + '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' > vllm_disagg_prefill.log 2>&1 & + +# decoding instance, which is the KV consumer +CUDA_VISIBLE_DEVICES=1 vllm disagg meta-llama/Meta-Llama-3.1-8B-Instruct \ + --zmq-server-addr testipc1 \ + --max-model-len 100 \ + --gpu-memory-utilization 0.8 \ + --kv-transfer-config \ + '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' > vllm_disagg_decode.log 2>&1 & + +# launch a proxy server that opens the service at port 8000 +# the workflow of this proxy: +# - send the request to prefill vLLM instance (via zmq addr testipc0), change max_tokens +# to 1 +# - after the prefill vLLM finishes prefill, send the request to decode vLLM +# instance (via zmq addr testipc1) +vllm connect --port 8000 \ + --prefill-addr testipc0 \ + --decode-addr testipc1 & + +# wait until prefill, decode instances and proxy are ready +wait_for_server 8000 +wait_for_disagg_server vllm_disagg_prefill.log +wait_for_disagg_server vllm_disagg_decode.log + +# serve two example requests +output1=$(curl -X POST -s http://localhost:8000/v1/completions \ +-H "Content-Type: application/json" \ +-d '{ +"model": "meta-llama/Meta-Llama-3.1-8B-Instruct", +"prompt": "San Francisco is a", +"max_tokens": 10, +"temperature": 0 +}') + +output2=$(curl -X POST -s http://localhost:8000/v1/completions \ +-H "Content-Type: application/json" \ +-d '{ +"model": "meta-llama/Meta-Llama-3.1-8B-Instruct", +"prompt": "Santa Clara is a", +"max_tokens": 10, +"temperature": 0 +}') + + +# Cleanup commands +pgrep python | xargs kill -9 +pkill -f python + +echo "" + +sleep 1 + +# Print the outputs of the curl requests +echo "" +echo "Output of first request: $output1" +echo "Output of second request: $output2" + +echo "🎉🎉 Successfully 
finished 2 test requests! 🎉🎉"
+echo ""
diff --git a/vllm/entrypoints/cli/connect.py b/vllm/entrypoints/cli/connect.py
new file mode 100644
index 0000000000000..467f6a518d033
--- /dev/null
+++ b/vllm/entrypoints/cli/connect.py
@@ -0,0 +1,72 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import argparse
+
+import uvloop
+
+from vllm.entrypoints.cli.types import CLISubcommand
+from vllm.entrypoints.disagg_connector import run_disagg_connector
+from vllm.utils import FlexibleArgumentParser
+
+
+class ConnectSubcommand(CLISubcommand):
+    """The `connect` subcommand for the vLLM CLI. """
+
+    def __init__(self):
+        self.name = "connect"
+        super().__init__()
+
+    @staticmethod
+    def cmd(args: argparse.Namespace) -> None:
+        uvloop.run(run_disagg_connector(args))
+
+    def validate(self, args: argparse.Namespace) -> None:
+        validate_connect_parsed_args(args)
+
+    def subparser_init(
+            self,
+            subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
+        connect_parser = subparsers.add_parser(
+            "connect",
+            help=
+            "Start the vLLM OpenAI Compatible API server that connects to "
+            "disaggregated prefill and decode servers",
+            usage="vllm connect [options]")
+
+        return make_connect_arg_parser(connect_parser)
+
+
+def cmd_init() -> list[CLISubcommand]:
+    return [ConnectSubcommand()]
+
+
+def make_connect_arg_parser(
+        parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
+    parser.add_argument("--port",
+                        type=int,
+                        default=7001,
+                        help="The FastAPI server port (default: 7001)")
+    # only ipc is supported for now; tcp (with auth) may be added later
+    parser.add_argument(
+        "--protocol",
+        type=str,
+        choices=["ipc"],
+        default="ipc",
+        help="The zmq socket protocol; only IPC "
+        "(Inter-Process Communication) is supported")
+    # for security reasons, only ipc is supported for now
+    parser.add_argument("--prefill-addr",
+                        type=str,
+                        required=True,
+                        help="The zmq ipc prefill address")
+    parser.add_argument("--decode-addr",
+                        type=str,
+                        required=True,
+                        help="The zmq ipc decode address")
+
+    return parser
+
+
+def validate_connect_parsed_args(args: argparse.Namespace):
+    """Quick checks for connect args that raise prior to loading."""
+    if hasattr(args, "subparser") and args.subparser != "connect":
+        return
diff --git a/vllm/entrypoints/cli/disagg.py b/vllm/entrypoints/cli/disagg.py
new file mode 100644
index 0000000000000..5d286e8e0b7aa
--- /dev/null
+++ b/vllm/entrypoints/cli/disagg.py
@@ -0,0 +1,115 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import argparse
+import asyncio
+
+from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
+from vllm.entrypoints.cli.types import CLISubcommand
+from vllm.entrypoints.openai.cli_args import (LoRAParserAction,
+                                              PromptAdapterParserAction)
+from vllm.entrypoints.openai.zmq_server import run_zmq_server
+from vllm.utils import FlexibleArgumentParser
+
+
+class DisaggSubcommand(CLISubcommand):
+    """The `disagg` subcommand for the vLLM CLI. """
+
+    def __init__(self):
+        self.name = "disagg"
+        super().__init__()
+
+    @staticmethod
+    def cmd(args: argparse.Namespace) -> None:
+        # The default value of `--model`
+        if not args.model_tag:
+            raise ValueError(
+                "With `vllm disagg`, you should provide the model as a "
+                "positional argument instead of via the `--model` option.")
+
+        # EngineArgs expects the model name to be passed as --model.
+ args.model = args.model_tag + + asyncio.run(run_zmq_server(args)) + + def validate(self, args: argparse.Namespace) -> None: + validate_parsed_disagg_args(args) + + def subparser_init( + self, + subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + disagg_parser = subparsers.add_parser( + "disagg", + help="Start the vLLM OpenAI Compatible API zmq server", + usage="vllm disagg [options]") + + return make_disagg_arg_parser(disagg_parser) + + +def cmd_init() -> list[CLISubcommand]: + return [DisaggSubcommand()] + + +def make_disagg_arg_parser( + parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + parser.add_argument( + "model_tag", + type=str, + help= + "The model tag to use for the vLLM OpenAI Compatible API zmq server.") + parser.add_argument('--zmq-server-addr', + type=str, + required=True, + help='The address to serve the zmq server on.') + parser.add_argument( + "--disable-frontend-multiprocessing", + action="store_true", + help="If specified, will run the OpenAI frontend server in the same " + "process as the model serving engine.") + parser.add_argument( + "--return-tokens-as-token-ids", + action="store_true", + help="When ``--max-logprobs`` is specified, represents single tokens " + " as strings of the form 'token_id:{token_id}' so that tokens " + "that are not JSON-encodable can be identified.") + parser.add_argument('--max-log-len', + type=int, + default=None, + help='Max number of prompt characters or prompt ' + 'ID numbers being printed in log.' + '\n\nDefault: Unlimited') + parser.add_argument( + "--lora-modules", + type=nullable_str, + default=None, + nargs='+', + action=LoRAParserAction, + help="LoRA module configurations in either 'name=path' format" + "or JSON format. " + "Example (old format): ``'name=path'`` " + "Example (new format): " + "``{\"name\": \"name\", \"path\": \"lora_path\", " + "\"base_model_name\": \"id\"}``") + + parser.add_argument( + "--prompt-adapters", + type=nullable_str, + default=None, + nargs='+', + action=PromptAdapterParserAction, + help="Prompt adapter configurations in the format name=path. 
" + "Multiple adapters can be specified.") + + AsyncEngineArgs.add_cli_args(parser) + + return parser + + +def validate_parsed_disagg_args(args: argparse.Namespace): + """Quick checks for model disagg args that raise prior to loading.""" + if hasattr(args, "subparser") and args.subparser != "disagg": + return + + # Enable reasoning needs a reasoning parser to be valid + if args.enable_reasoning and not args.reasoning_parser: + raise TypeError("Error: --enable-reasoning requires " + "--reasoning-parser") diff --git a/vllm/entrypoints/cli/main.py b/vllm/entrypoints/cli/main.py index 13f2761b0db06..e9edbf663ee8e 100644 --- a/vllm/entrypoints/cli/main.py +++ b/vllm/entrypoints/cli/main.py @@ -6,6 +6,8 @@ import signal import sys import vllm.entrypoints.cli.benchmark.main +import vllm.entrypoints.cli.connect +import vllm.entrypoints.cli.disagg import vllm.entrypoints.cli.openai import vllm.entrypoints.cli.serve import vllm.version @@ -18,6 +20,8 @@ CMD_MODULES = [ vllm.entrypoints.cli.openai, vllm.entrypoints.cli.serve, vllm.entrypoints.cli.benchmark.main, + vllm.entrypoints.cli.disagg, + vllm.entrypoints.cli.connect, ] diff --git a/vllm/entrypoints/disagg_connector.py b/vllm/entrypoints/disagg_connector.py index 3eb976d3fbc99..cb6549dd40bc2 100644 --- a/vllm/entrypoints/disagg_connector.py +++ b/vllm/entrypoints/disagg_connector.py @@ -3,44 +3,69 @@ import asyncio import json import signal +import sys import traceback import uuid -from asyncio import Queue +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -from typing import AsyncGenerator +from typing import Union import uvicorn import uvloop import zmq import zmq.asyncio -from fastapi import FastAPI, Request +from fastapi import BackgroundTasks, FastAPI, Request from fastapi.responses import JSONResponse, StreamingResponse +from vllm.entrypoints.openai.zmq_server import (CONTENT_TYPE_ERROR, + CONTENT_TYPE_JSON, + CONTENT_TYPE_STREAM) from vllm.logger import init_logger from vllm.utils import FlexibleArgumentParser -# default prefill and decode addr -time_out = 180 -socket_prefill_num = 100 -socket_decode_num = 100 -context_type_json = "application/json" -context_type_error = "error" - # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) logger = init_logger('vllm.entrypoints.disagg_connector') +TIME_OUT = 5 +X_REQUEST_ID_KEY = "X-Request-Id" + +# communication between output handlers and execute_task_async +request_queues: dict[str, asyncio.Queue] + + +async def log_stats(request_queues: dict[str, asyncio.Queue]): + while True: + logger.info("Running requests: %d", len(request_queues)) + await asyncio.sleep(10) + + +# create async socket use ZMQ_DEALER +async def create_socket(url: str, + zmqctx: zmq.asyncio.Context) -> zmq.asyncio.Socket: + sock = zmqctx.socket(zmq.DEALER) + identity = f"connector-{uuid.uuid4()}" + sock.setsockopt(zmq.IDENTITY, identity.encode()) + sock.connect(url) + logger.info("%s started at %s", identity, url) + return sock + @asynccontextmanager async def lifespan(app: FastAPI): # create socket pool with prefill and decode - logger.info("start create_socket_pool") + logger.info("start connect zmq server") app.state.zmqctx = zmq.asyncio.Context() - app.state.sockets_prefill = await create_socket_pool( - app.state.prefill_addr, socket_prefill_num, zmqctx=app.state.zmqctx) - logger.info("success create_socket_pool sockets_prefill") - app.state.sockets_decode = await create_socket_pool( - app.state.decode_addr, socket_decode_num, zmqctx=app.state.zmqctx) - 
logger.info("success create_socket_pool sockets_decode") + app.state.prefill_socket = await create_socket(app.state.prefill_addr, + zmqctx=app.state.zmqctx) + logger.info("success create_socke sockets_prefill") + app.state.decode_socket = await create_socket(app.state.decode_addr, + zmqctx=app.state.zmqctx) + logger.info("success create_socket sockets_decode") + global request_queues + request_queues = {} + asyncio.create_task(prefill_handler(app.state.prefill_socket)) + asyncio.create_task(decode_handler(app.state.decode_socket)) + asyncio.create_task(log_stats(request_queues)) yield ## close zmq context logger.info("shutdown disagg connector") @@ -51,59 +76,140 @@ async def lifespan(app: FastAPI): app = FastAPI(lifespan=lifespan) -# create async socket pool with num_sockets use ZMQ_DEALER -async def create_socket_pool(url: str, num_sockets: int, - zmqctx: zmq.asyncio.Context) -> Queue: - sockets: Queue[zmq.Socket] = Queue() - for i in range(num_sockets): - sock = zmqctx.socket(zmq.DEALER) - identity = f"worker-{i}-{uuid.uuid4()}" - sock.setsockopt(zmq.IDENTITY, identity.encode()) - sock.connect(url) - logger.info("%s started at %s with queue size %s", identity, url, - sockets.qsize()) - await sockets.put(sock) - return sockets +@app.post('/v1/completions') +async def completions(request: Request, background_tasks: BackgroundTasks): + try: + # Add the X-Request-Id header to the raw headers list + header = dict(request.headers) + request_id = header.get(X_REQUEST_ID_KEY) + queue = asyncio.Queue() + if request_id is None: + request_id = str(uuid.uuid4()) + logger.debug("add X-Request-Id: %s", request_id) + header[X_REQUEST_ID_KEY] = request_id + logger.debug("X-Request-Id is: %s", request_id) + request_queues[request_id] = queue + request_data = await request.json() + logger.info("Received request_id: %s, request: %s, header: %s", + request_id, request_data, header) + original_max_tokens = request_data['max_tokens'] + # change max_tokens = 1 to let it only do prefill + request_data['max_tokens'] = 1 + # finish prefill + try: + prefill_response = await prefill(header, request_data) + if isinstance(prefill_response, JSONResponse): + return prefill_response + logger.debug("finish prefill start decode") + request_data['max_tokens'] = original_max_tokens + response = await decode(header, request_data) + logger.debug("finish decode") + except Exception as e: + logger.error("Error occurred in disagg prefill proxy server, %s", + e) + response = JSONResponse({"error": { + "message": str(e) + }}, + status_code=500) + return response + + except Exception as e: + exc_info = sys.exc_info() + logger.error("Error occurred in disagg prefill proxy server") + logger.error(e) + logger.error("".join(traceback.format_exception(*exc_info))) + response = JSONResponse({"error": { + "message": str(e) + }}, + status_code=500) + return response + finally: + background_tasks.add_task(cleanup_request_id, request_id) + + +async def socket_recv_handler(socket: zmq.asyncio.Socket, scene: str): + while True: + try: + [request_id, contentType, reply] = await socket.recv_multipart() + contentType_str = contentType.decode() + reply_str = reply.decode() + request_id_str = request_id.decode() + logger.debug( + "%s socket received result contentType: %s, " + "request_id: %s, reply: %s", scene, contentType_str, + request_id_str, reply_str) + if request_id_str in request_queues: + request_queues[request_id_str].put_nowait( + (contentType_str, reply_str)) + if "[DONE]" in reply_str: + logger.debug( + "%s socket received stop signal 
request_id: %s", scene, + request_id_str) + request_queues[request_id_str].put_nowait( + (contentType_str, None)) + else: + logger.debug( + "%s socket received but request_id not found discard: %s", + scene, request_id_str) + except Exception as e: + logger.error(traceback.format_exc()) + logger.error("%s handler error: %s", scene, e) + + +# prefill handler +async def prefill_handler(prefill_socket: zmq.asyncio.Socket): + await socket_recv_handler(prefill_socket, "prefill") + + +# decode handler +async def decode_handler(decode_socket: zmq.asyncio.Socket): + await socket_recv_handler(decode_socket, "decode") # select a socket and execute task -async def execute_task_async(route: str, headers: dict, request: dict, - sockets: Queue): - sock: zmq.Socket = await sockets.get() +async def execute_task_async(headers: dict, request: dict, + socket: zmq.asyncio.Socket): try: + request_id = headers.get(X_REQUEST_ID_KEY) requestBody = json.dumps(request) - headersJson = json.dumps(headers) - logger.info("Sending requestBody: %s to %s with headers: %s", - requestBody, route, headersJson) - await asyncio.wait_for(sock.send_multipart( - [route.encode(), - headersJson.encode(), - requestBody.encode()]), - timeout=time_out) + logger.info("Sending requestBody: %s", requestBody) + socket.send_multipart([request_id.encode(), requestBody.encode()]) logger.debug("Sent end") + queue = request_queues[request_id] while True: logger.debug("Waiting for reply") - [contentType, - reply] = await asyncio.wait_for(sock.recv_multipart(), - timeout=time_out) - contentType_str = contentType.decode() - reply_str = reply.decode() - logger.debug("Received result: %s, %s", contentType_str, reply_str) - yield (contentType_str, reply_str) - if context_type_json == contentType_str: - logger.debug("Received %s message, return socket", - contentType_str) + (contentType, + reply) = await asyncio.wait_for(queue.get(), TIME_OUT) + logger.debug("Received result: %s, %s", contentType, reply) + if reply is None: + logger.debug("Received stop signal, request_id: %s", + request_id) break - if "[DONE]" in reply_str: - logger.debug("Received stop signal, return socket") + yield (contentType, reply) + if contentType == CONTENT_TYPE_JSON: + logger.debug("Received %s message, request_id: %s", + contentType, request_id) break + except asyncio.TimeoutError: logger.error(traceback.format_exc()) - logger.error("Timeout, return socket: %s", - sock.getsockopt(zmq.IDENTITY)) - yield (context_type_error, "System Error") + yield (CONTENT_TYPE_ERROR, "System Error") finally: - await sockets.put(sock) + logger.debug("request_id: %s, execute_task_async end", request_id) + + +async def prefill(header: dict, + original_request_data: dict) -> Union[JSONResponse, bool]: + logger.debug("start prefill") + generator = execute_task_async(header, original_request_data, + app.state.prefill_socket) + async for contentType, reply in generator: + logger.debug("contentType: %s, reply: %s", contentType, reply) + if contentType == CONTENT_TYPE_ERROR: + response = JSONResponse({"error": reply}) + response.status_code = 500 + return response + return True async def generate_stream_response(fisrt_reply: str, @@ -113,78 +219,34 @@ async def generate_stream_response(fisrt_reply: str, yield reply -async def prefill(route: str, header: dict, original_request_data: dict): - logger.info("start prefill") - generator = execute_task_async(route, header, original_request_data, - app.state.sockets_prefill) - async for contentType, reply in generator: - logger.debug("contentType: %s, 
reply: %s", contentType, reply) - if context_type_error == contentType: - response = JSONResponse({"error": reply}) - response.status_code = 500 - return response - return True - - -async def decode(route: str, header: dict, original_request_data: dict): +async def decode( + header: dict, + original_request_data: dict) -> Union[JSONResponse, StreamingResponse]: logger.info("start decode") - generator = execute_task_async(route, header, original_request_data, - app.state.sockets_decode) + generator = execute_task_async(header, original_request_data, + app.state.decode_socket) async for contentType, reply in generator: logger.debug("contentType: %s, reply: %s", contentType, reply) - if context_type_error == contentType: + if contentType == CONTENT_TYPE_ERROR: response = JSONResponse({"error": reply}) response.status_code = 500 return response - elif context_type_json == contentType: + elif contentType == CONTENT_TYPE_JSON: return JSONResponse(reply) else: return StreamingResponse(generate_stream_response( reply, generator), - media_type="text/event-stream") + media_type=CONTENT_TYPE_STREAM) -@app.post('/v1/completions') -async def chat_completions(request: Request): - try: - # Add the X-Request-Id header to the raw headers list - x_request_id = str(uuid.uuid4()) - header = dict(request.headers) - if header.get("X-Request-Id") is None: - logger.info("add X-Request-Id: %s", x_request_id) - header["X-Request-Id"] = x_request_id - request_data = await request.json() - logger.info("Received request: %s header: %s", request_data, header) - original_max_tokens = request_data['max_tokens'] - # change max_tokens = 1 to let it only do prefill - request_data['max_tokens'] = 1 - route = "/v1/completions" - # finish prefill - try: - prefill_response = await prefill(route, header, request_data) - if isinstance(prefill_response, JSONResponse): - return prefill_response - logger.info("finish prefill start decode") - request_data['max_tokens'] = original_max_tokens - response = await decode(route, header, request_data) - logger.info("finish decode") - except Exception as e: - logger.error("Error occurred in disagg prefill proxy server, %s", - e) - response = JSONResponse({"error": {"message": str(e)}}) - return response - - except Exception as e: - import sys - import traceback - exc_info = sys.exc_info() - logger.error("Error occurred in disagg prefill proxy server") - logger.error(e) - logger.error("".join(traceback.format_exception(*exc_info))) +def cleanup_request_id(request_id: str): + if request_id in request_queues: + logger.info("del request_id: %s, decode finished", request_id) + del request_queues[request_id] -async def run_disagg_connector(args, **uvicorn_kwargs) -> None: +async def run_disagg_connector(args, **uvicorn_kwargs): logger.info("vLLM Disaggregate Connector start %s %s", args, uvicorn_kwargs) logger.info(args.prefill_addr) @@ -192,8 +254,9 @@ async def run_disagg_connector(args, **uvicorn_kwargs) -> None: app.state.prefill_addr = f"ipc://{args.prefill_addr}" app.state.decode_addr = f"ipc://{args.decode_addr}" logger.info( - "start connect prefill_addr: %s decode_addr: %s zmq server port: %s", - app.state.prefill_addr, app.state.decode_addr, app.state.port) + "start connect prefill_addr: %s decode_addr: %s " + "zmq server fastapi port: %s", app.state.prefill_addr, + app.state.decode_addr, app.state.port) def signal_handler(*_) -> None: # Interrupt server on sigterm while initializing @@ -208,20 +271,22 @@ async def run_disagg_connector(args, **uvicorn_kwargs) -> None: if __name__ == 
"__main__": # NOTE(simon): - # This section should be in sync with vllm/scripts.py for CLI entrypoints. - parser = FlexibleArgumentParser(description="vLLM disagg zmq server.") + # This section should be sync with vllm/entrypoints/cli/connect.py for CLI + # entrypoints. + parser = FlexibleArgumentParser(description="vLLM disagg connect server.") parser.add_argument("--port", type=int, - default=8000, - help="The fastapi server port") + default=8001, + help="The fastapi server port default 8001") + # security concern only support ipc now parser.add_argument("--prefill-addr", type=str, required=True, - help="The prefill address IP:PORT") + help="The zmq ipc prefill address") parser.add_argument("--decode-addr", type=str, required=True, - help="The decode address IP:PORT") + help="The zmq ipc decode address") args = parser.parse_args() diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index e9c865b953401..b09ee526f14ae 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -7,16 +7,12 @@ from http import HTTPStatus from typing import Any, Optional import uvicorn -import zmq -import zmq.asyncio -import zmq.devices from fastapi import FastAPI, Request, Response from vllm import envs from vllm.engine.async_llm_engine import AsyncEngineDeadError from vllm.engine.multiprocessing import MQEngineDeadError from vllm.entrypoints.ssl import SSLCertRefresher -from vllm.entrypoints.openai.connect_worker import worker_routine from vllm.logger import init_logger from vllm.utils import find_process_using_port @@ -79,43 +75,6 @@ async def serve_http(app: FastAPI, return server.shutdown() -async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None: - """Server routine""" - logger.info("zmq Server start arg: %s, zmq_server_port: %d", arg, - zmq_server_port) - # different zmq context can't communicate use inproc - workers_addr = "ipc://workers" - clients_addr = f"ipc://127.0.0.1:{zmq_server_port}" - # Prepare our context and sockets - context = zmq.asyncio.Context.instance() - try: - tasks = [ - asyncio.create_task(worker_routine(workers_addr, app, context, i)) - for i in range(100) - ] - logger.info("zmq tasks: %s", tasks) - # thread safety proxy create socket in the background: - # https://pyzmq.readthedocs.io/en/latest/api/zmq.devices.html#proxy-devices - thread_proxy = zmq.devices.ThreadProxy(zmq.ROUTER, zmq.DEALER) - # unlimited HWM - hwm_limit = 0 - thread_proxy.bind_in(clients_addr) - thread_proxy.setsockopt_in(zmq.SNDHWM, hwm_limit) - thread_proxy.setsockopt_in(zmq.RCVHWM, hwm_limit) - thread_proxy.bind_out(workers_addr) - thread_proxy.setsockopt_out(zmq.SNDHWM, hwm_limit) - thread_proxy.setsockopt_out(zmq.RCVHWM, hwm_limit) - thread_proxy.start() - await asyncio.gather(*tasks) - except KeyboardInterrupt: - print("ZMQ Server interrupted") - except zmq.ZMQError as e: - print("ZMQError:", e) - finally: - # We never get here but clean up anyhow - context.destroy(linger=0) - - def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None: """Adds handlers for fatal errors that should crash the server""" diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 86004feb7fee5..3040793eb005a 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -36,7 +36,7 @@ from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.engine.multiprocessing.engine import run_mp_engine from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils 
import load_chat_template -from vllm.entrypoints.launcher import serve_http, serve_zmq +from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.cli_args import (make_arg_parser, validate_parsed_serve_args) diff --git a/vllm/entrypoints/openai/connect_worker.py b/vllm/entrypoints/openai/connect_worker.py deleted file mode 100644 index c9c42df8fb613..0000000000000 --- a/vllm/entrypoints/openai/connect_worker.py +++ /dev/null @@ -1,139 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import json -import tempfile -import traceback -import uuid -from typing import Optional - -import httpx -import zmq -import zmq.asyncio -from fastapi import FastAPI, Request - -# yapf conflicts with isort for this block -# yapf: disable -from vllm.entrypoints.openai.protocol import (CompletionRequest, - CompletionResponse, - ErrorResponse) -from vllm.entrypoints.openai.serving_chat import OpenAIServingChat -from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion -from vllm.entrypoints.openai.serving_engine import OpenAIServing -from vllm.entrypoints.openai.serving_models import OpenAIServingModels -from vllm.entrypoints.openai.serving_tokenization import ( - OpenAIServingTokenization) -from vllm.logger import init_logger - -prometheus_multiproc_dir: tempfile.TemporaryDirectory - -# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) -logger = init_logger('vllm.entrypoints.openai.connect_worker') - -def base(app: FastAPI) -> OpenAIServing: - # Reuse the existing instance - return tokenization(app) - - -def models(app: FastAPI) -> OpenAIServingModels: - return app.state.openai_serving_models - - -def chat(app: FastAPI) -> Optional[OpenAIServingChat]: - return app.state.openai_serving_chat - - -def completion(app: FastAPI) -> Optional[OpenAIServingCompletion]: - return app.state.openai_serving_completion - -def tokenization(app: FastAPI) -> OpenAIServingTokenization: - return app.state.openai_serving_tokenization - - -def bytes_to_headers(bytes_data: bytes) -> httpx.Headers: - headers_dict = json.loads(bytes_data.decode()) - return httpx.Headers(headers_dict) - -async def worker_routine(worker_addr: str, app: FastAPI, - context: zmq.asyncio.Context, i: int = 0): - """Worker routine""" - try: - # Socket to talk to dispatcher - socket = context.socket(zmq.DEALER) - worker_identity = f"worker-{i}-{uuid.uuid4()}" - socket.setsockopt(zmq.IDENTITY, worker_identity.encode()) - socket.connect(worker_addr) - logger.info("%s started at %s", worker_identity, worker_addr) - while True: - identity, url, header, body = await socket.recv_multipart() - logger.info("worker-%d Received request identity: [ %s ]", - i, identity.decode()) - url_str = url.decode() - logger.info("worker-%d Received request url: [ %s ]", - i, url_str) - headers = bytes_to_headers(header) - logger.info("worker-%d Received request headers: [ %s ]", - i, headers) - body_json = json.loads(body.decode()) - logger.info("worker-%d Received request body: [ %s ]", - i, body_json) - logger.info("worker-%d Calling OpenAI API", i) - completionRequest = CompletionRequest(**body_json) - createRequest = create_request(url_str, "POST", body_json, headers) - generator = await create_completion(app, completionRequest, - createRequest) - context_type_json = b"application/json" - if isinstance(generator, ErrorResponse): - content = generator.model_dump_json() - context_json = json.loads(content) - context_json.append("status_code", generator.code) - await 
socket.send_multipart([identity, context_type_json, - json.dumps(context_json).encode('utf-8')]) - elif isinstance(generator, CompletionResponse): - await socket.send_multipart([identity, - context_type_json, - json.dumps(generator.model_dump()).encode('utf-8')]) - else: - async for chunk in generator: - await socket.send_multipart([identity, - b"text/event-stream", - chunk.encode('utf-8')]) - except Exception as e: - logger.error("Error in worker routine: %s worker-%d", e, i) - logger.error(traceback.format_exc()) - -async def create_completion(app: FastAPI, request: CompletionRequest, - raw_request: Request): - handler = completion(app) - logger.info("zmq request post: %s", request) - if handler is None: - return base(app).create_error_response( - message="The model does not support Completions API") - - generator = await handler.create_completion(request, raw_request) - logger.info("zmq request end post: %s", generator) - return generator - - -def create_request(path: str, method: str, body: dict, - headers: httpx.Headers) -> Request: - scope = { - 'type': 'http', - 'http_version': '1.1', - 'method': method, - 'path': path, - 'headers': list(headers.items()) if headers else [], - } - if body: - scope['body'] = json.dumps(body) - async def receive(): - return { - 'type': 'http.request', - 'body': scope.get('body', b''), - } - async def send(message): - pass - return Request(scope, receive=receive, send=send) - - -if __name__ == "__main__": - print(bytes_to_headers(b'{"Content-Type": "application/json"}')) diff --git a/vllm/entrypoints/openai/zmq_server.py b/vllm/entrypoints/openai/zmq_server.py new file mode 100644 index 0000000000000..b5404a7d766ae --- /dev/null +++ b/vllm/entrypoints/openai/zmq_server.py @@ -0,0 +1,211 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import json +import os +import signal +import traceback +from argparse import Namespace + +import zmq +import zmq.asyncio +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.openai.api_server import build_async_engine_client +from vllm.entrypoints.openai.protocol import (CompletionRequest, + CompletionResponse, + ErrorResponse) +from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion +from vllm.entrypoints.openai.serving_models import (BaseModelPath, + OpenAIServingModels) +from vllm.logger import init_logger +from vllm.utils import set_ulimit +from vllm.version import __version__ as VLLM_VERSION + +logger = init_logger('vllm.entrypoints.openai.zmq_server') + +CONTENT_TYPE_JSON = "application/json" +CONTENT_TYPE_ERROR = "error" +CONTENT_TYPE_STREAM = "text/event-stream" + +openai_serving_completion: OpenAIServingCompletion +openai_serving_models: OpenAIServingModels + + +async def log_stats(running_requests: set[asyncio.Task]): + while True: + logger.info("Running requests: %d", len(running_requests)) + await asyncio.sleep(10) + + +def _cleanup_ipc_path(server_addr: str): + socket_path = server_addr.replace("ipc://", "") + logger.info("cleaning up local IPC socket file %s", socket_path) + if os.path.exists(socket_path): + os.remove(socket_path) + + +async def serve_zmq(arg) -> None: + """Server routine""" + logger.info("zmq Server start arg: %s, zmq_server_addr: %s", arg, + arg.zmq_server_addr) + # different zmq context can't communicate use inproc + server_addr = f"ipc://{arg.zmq_server_addr}" + try: + # Prepare our context and sockets + context = zmq.asyncio.Context() + socket = 
context.socket(zmq.ROUTER) + # unlimited HWM + hwm_limit = 0 + + socket.bind(server_addr) + socket.setsockopt(zmq.SNDHWM, hwm_limit) + socket.setsockopt(zmq.RCVHWM, hwm_limit) + + running_requests: set[asyncio.Task] = set() + logger.info("zmq Server started at %s", server_addr) + asyncio.create_task(log_stats(running_requests)) + + while True: + try: + logger.debug("zmq Server waiting for request") + # get new request from the client + identity, request_id, body = await socket.recv_multipart() + # launch request handler coroutine + task = asyncio.create_task( + worker_routine(identity, request_id, body, socket)) + running_requests.add(task) + task.add_done_callback(running_requests.discard) + except zmq.ZMQError as e: + logger.error("ZMQError: %s", e) + break + except Exception as e: + logger.error("Unexpected error: %s", e) + break + except KeyboardInterrupt: + logger.info("KeyboardInterrupt received, exiting") + finally: + # Clean up resources + for task in running_requests: + task.cancel() + await asyncio.gather(*running_requests, return_exceptions=True) + socket.close() + context.destroy(linger=0) + _cleanup_ipc_path(server_addr) + + +async def run_zmq_server(args) -> None: + logger.info("vLLM zmq server version %s", VLLM_VERSION) + logger.info("args: %s", args) + + # workaround to avoid footguns where uvicorn drops requests with too + # many concurrent requests active + set_ulimit() + + def signal_handler(*_) -> None: + # Interrupt server on sigterm while initializing + raise KeyboardInterrupt("terminated") + + signal.signal(signal.SIGTERM, signal_handler) + + async with build_async_engine_client(args) as engine_client: + + model_config = await engine_client.get_model_config() + await init_state(engine_client, model_config, args) + logger.info("init_state successful") + await serve_zmq(args) + + +async def init_state( + engine_client: EngineClient, + model_config: ModelConfig, + args: Namespace, +) -> None: + if args.served_model_name is not None: + served_model_names = args.served_model_name + else: + served_model_names = [args.model] + + base_model_paths = [ + BaseModelPath(name=name, model_path=args.model) + for name in served_model_names + ] + + global openai_serving_models + openai_serving_models = OpenAIServingModels( + engine_client=engine_client, + model_config=model_config, + base_model_paths=base_model_paths, + lora_modules=args.lora_modules, + prompt_adapters=args.prompt_adapters, + ) + await openai_serving_models.init_static_loras() + + global openai_serving_completion + openai_serving_completion = OpenAIServingCompletion( + engine_client, + model_config, + openai_serving_models, + request_logger=None, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + ) + + +async def worker_routine(identity: bytes, request_id: bytes, body: bytes, + socket: zmq.asyncio.Socket): + """Worker routine""" + try: + body_json = json.loads(body.decode()) + request_id_str = request_id.decode() + logger.debug("receive request: %s from %s, request_id: %s", body_json, + identity.decode(), request_id_str) + + completionRequest = CompletionRequest(**body_json) + generator = await create_completion(completionRequest, None) + content_type_json = CONTENT_TYPE_JSON.encode('utf-8') + content_type_stream = CONTENT_TYPE_STREAM.encode('utf-8') + if isinstance(generator, ErrorResponse): + content = json.loads(generator.model_dump_json()) + content.update({"status_code": generator.code}) + logger.debug("send ErrorResponse %s", json.dumps(content)) + await socket.send_multipart([ + identity, 
request_id, content_type_json,
+                json.dumps(content).encode('utf-8')
+            ])
+        elif isinstance(generator, CompletionResponse):
+            logger.debug("send CompletionResponse %s",
+                         json.dumps(generator.model_dump()))
+            await socket.send_multipart([
+                identity, request_id, content_type_json,
+                json.dumps(generator.model_dump()).encode('utf-8')
+            ])
+        else:
+            async for chunk in generator:
+                logger.debug(
+                    "send chunk identity: %s, request_id: %s, chunk: %s",
+                    identity.decode(), request_id.decode(), chunk)
+                await socket.send_multipart([
+                    identity, request_id, content_type_stream,
+                    chunk.encode('utf-8')
+                ])
+    except Exception as e:
+        logger.error("Error in worker routine: %s request_id: %s", e,
+                     request_id.decode())
+        logger.error(traceback.format_exc())
+        logger.debug("send ErrorResponse %s", str(e))
+        await socket.send_multipart([
+            identity, request_id,
+            CONTENT_TYPE_ERROR.encode('utf-8'),
+            str(e).encode('utf-8')
+        ])
+
+
+async def create_completion(request: CompletionRequest, raw_request: Request):
+    logger.debug("zmq request post: %s", request)
+    generator = await openai_serving_completion.create_completion(
+        request, raw_request)
+    logger.debug("zmq request end post")
+    return generator
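
Note: the sketch below is a minimal standalone DEALER client for the `vllm disagg` ZMQ endpoint, included only to illustrate the wire framing this patch establishes (worker_routine in zmq_server.py and its counterpart in disagg_connector.py): requests go out as [request_id, body] and replies come back as [request_id, content_type, payload]. The ipc path "testipc0" and the model name are assumptions borrowed from disaggregated_prefill_zmq.sh above; this is not a supported client API, just a sketch of the protocol.

# hypothetical_zmq_client.py -- illustrative sketch, not part of the patch
import json
import uuid

import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.DEALER)
# Each DEALER peer needs a unique identity; the server-side ROUTER socket
# uses it to route replies back to this client.
sock.setsockopt(zmq.IDENTITY, f"client-{uuid.uuid4()}".encode())
sock.connect("ipc://testipc0")  # matches --zmq-server-addr testipc0

request_id = str(uuid.uuid4())
body = json.dumps({
    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "prompt": "San Francisco is a",
    "max_tokens": 10,
    "temperature": 0,
})
# worker_routine() expects [request_id, body]; the ROUTER socket prepends
# this client's identity frame on receipt.
sock.send_multipart([request_id.encode(), body.encode()])

while True:
    # Replies arrive as [request_id, content_type, payload].
    _rid, content_type, payload = sock.recv_multipart()
    print(content_type.decode(), payload.decode())
    # "application/json" and "error" are terminal; streamed chunks use
    # "text/event-stream" and finish with a "[DONE]" sentinel.
    if content_type != b"text/event-stream" or b"[DONE]" in payload:
        break

sock.close()
ctx.term()

The in-process connector does the same thing asynchronously (zmq.asyncio) and demultiplexes concurrent replies into per-request asyncio queues keyed by the X-Request-Id header.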