commit 912031ceb5 (parent 4f13e89143)

refactor disagg

Signed-off-by: clark <panf2333@gmail.com>
@@ -26,6 +26,14 @@ cleanup() {

export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')

# install quart first -- required for the disagg prefill proxy server
if python3 -c "import quart" &> /dev/null; then
    echo "Quart is already installed."
else
    echo "Quart is not installed. Installing..."
    python3 -m pip install quart
fi

# a function that waits for a vLLM server to start
wait_for_server() {
    local port=$1
@@ -41,7 +49,6 @@ wait_for_server() {

# prefilling instance, which is the KV producer
CUDA_VISIBLE_DEVICES=0 vllm serve $MODEL_NAME \
    --port 8100 \
    --zmq-server-port 7010 \
    --max-model-len 100 \
    --gpu-memory-utilization 0.8 \
    --trust-remote-code \
@@ -51,25 +58,13 @@ CUDA_VISIBLE_DEVICES=0 vllm serve $MODEL_NAME \

# decoding instance, which is the KV consumer
CUDA_VISIBLE_DEVICES=1 vllm serve $MODEL_NAME \
    --port 8200 \
    --zmq-server-port 7011 \
    --max-model-len 100 \
    --gpu-memory-utilization 0.8 \
    --trust-remote-code \
    --kv-transfer-config \
    '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' &

# launch a proxy server that opens the service at port 8000
# the workflow of this proxy:
# - send the request to the prefill vLLM instance (via zmq port 7010),
#   changing max_tokens to 1
# - after the prefill vLLM instance finishes prefill, send the request to the
#   decode vLLM instance (via zmq port 7011)
vllm connect --port 8000 \
    --prefill-addr 127.0.0.1:7010 \
    --decode-addr 127.0.0.1:7011 &
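The two-phase workflow described in the comments above reduces to a small piece of control logic. A minimal sketch (illustrative only; `send_to_prefill` and `send_to_decode` are hypothetical stand-ins for the connector's zmq calls, not part of this commit):

# prefill with max_tokens=1 so the producer only fills the KV cache,
# then replay the request with its original max_tokens against the consumer.
async def proxy_request(request_data: dict, send_to_prefill, send_to_decode):
    original_max_tokens = request_data["max_tokens"]
    request_data["max_tokens"] = 1               # prefill-only pass
    await send_to_prefill(request_data)          # KV cache is produced here
    request_data["max_tokens"] = original_max_tokens
    return await send_to_decode(request_data)    # decode reuses the KV cache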

# wait until the prefill and decode instances and the proxy are ready
wait_for_server 8000

# wait until the prefill and decode instances are ready
wait_for_server 8100
wait_for_server 8200

@@ -118,4 +113,4 @@ echo "Output of first request: $output1"

echo "Output of second request: $output2"

echo "🎉🎉 Successfully finished 2 test requests! 🎉🎉"
echo ""
echo ""

examples/online_serving/disaggregated_prefill_zmq.sh (new file, 113 lines)

@@ -0,0 +1,113 @@
#!/bin/bash
# This file demonstrates the example usage of disaggregated prefilling with ZMQ
# We will launch 2 vllm instances (1 for prefill and 1 for decode),
# and then transfer the KV cache between them.

set -xe

echo "🚧🚧 Warning: The usage of disaggregated prefill is experimental and subject to change 🚧🚧"
sleep 1

# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'cleanup' INT

# Cleanup function
cleanup() {
    echo "Caught Ctrl+C, cleaning up..."
    # Cleanup commands
    pgrep python | xargs kill -9
    pkill -f python
    echo "Cleanup complete. Exiting."
    exit 0
}

export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')

# a function that waits for the vllm connect proxy to start
wait_for_server() {
    local port=$1
    timeout 1200 bash -c "
        until curl -s localhost:${port}/v1/completions > /dev/null; do
            sleep 1
        done" && return 0 || return 1
}


# a function that waits for a vllm disagg instance to start
wait_for_disagg_server() {
    local log_file=$1
    timeout 1200 bash -c "
        until grep -q 'zmq Server started at' $log_file; do
            sleep 1
        done" && return 0 || return 1
}

# You can also adjust --kv-ip and --kv-port for distributed inference.

# prefilling instance, which is the KV producer
CUDA_VISIBLE_DEVICES=0 vllm disagg meta-llama/Meta-Llama-3.1-8B-Instruct \
    --zmq-server-addr testipc0 \
    --max-model-len 100 \
    --gpu-memory-utilization 0.8 \
    --kv-transfer-config \
    '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' > vllm_disagg_prefill.log 2>&1 &

# decoding instance, which is the KV consumer
CUDA_VISIBLE_DEVICES=1 vllm disagg meta-llama/Meta-Llama-3.1-8B-Instruct \
    --zmq-server-addr testipc1 \
    --max-model-len 100 \
    --gpu-memory-utilization 0.8 \
    --kv-transfer-config \
    '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' > vllm_disagg_decode.log 2>&1 &

# launch a proxy server that opens the service at port 8000
# the workflow of this proxy:
# - send the request to the prefill vLLM instance (via zmq addr testipc0),
#   changing max_tokens to 1
# - after the prefill vLLM instance finishes prefill, send the request to the
#   decode vLLM instance (via zmq addr testipc1)
vllm connect --port 8000 \
    --prefill-addr testipc0 \
    --decode-addr testipc1 &

# wait until the prefill and decode instances and the proxy are ready
wait_for_server 8000
wait_for_disagg_server vllm_disagg_prefill.log
wait_for_disagg_server vllm_disagg_decode.log

# serve two example requests
output1=$(curl -X POST -s http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "prompt": "San Francisco is a",
        "max_tokens": 10,
        "temperature": 0
    }')

output2=$(curl -X POST -s http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "prompt": "Santa Clara is a",
        "max_tokens": 10,
        "temperature": 0
    }')

# Cleanup commands
pgrep python | xargs kill -9
pkill -f python

echo ""

sleep 1

# Print the outputs of the curl requests
echo ""
echo "Output of first request: $output1"
echo "Output of second request: $output2"

echo "🎉🎉 Successfully finished 2 test requests! 🎉🎉"
echo ""

vllm/entrypoints/cli/connect.py (new file, 72 lines)

@@ -0,0 +1,72 @@
# SPDX-License-Identifier: Apache-2.0

import argparse

import uvloop

from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.disagg_connector import run_disagg_connector
from vllm.utils import FlexibleArgumentParser


class ConnectSubcommand(CLISubcommand):
    """The `connect` subcommand for the vLLM CLI."""

    def __init__(self):
        self.name = "connect"
        super().__init__()

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        uvloop.run(run_disagg_connector(args))

    def validate(self, args: argparse.Namespace) -> None:
        validate_connect_parsed_args(args)

    def subparser_init(
            self,
            subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
        connect_parser = subparsers.add_parser(
            "connect",
            help="Start the vLLM OpenAI Compatible API Server which connects "
            "to other servers to disaggregate prefill and decode",
            usage="vllm connect [options]")

        return make_connect_arg_parser(connect_parser)


def cmd_init() -> list[CLISubcommand]:
    return [ConnectSubcommand()]


def make_connect_arg_parser(
        parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
    parser.add_argument("--port",
                        type=int,
                        default=7001,
                        help="The fastapi server port (default: 7001)")
    # only ipc is supported for now; tcp (with auth) may be supported later
    parser.add_argument(
        "--protocol",
        type=str,
        choices=["ipc"],
        default="ipc",
        help="The zmq socket addr protocol: IPC (Inter-Process Communication)")
    # for security reasons, only ipc is supported now
    parser.add_argument("--prefill-addr",
                        type=str,
                        required=True,
                        help="The zmq ipc prefill address")
    parser.add_argument("--decode-addr",
                        type=str,
                        required=True,
                        help="The zmq ipc decode address")

    return parser


def validate_connect_parsed_args(args: argparse.Namespace):
    """Quick checks for connect args that raise prior to loading."""
    if hasattr(args, "subparser") and args.subparser != "connect":
        return
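A quick way to exercise the parser above outside the CLI (a sketch; assumes a vLLM checkout where these imports resolve, and uses the defaults defined in this file):

from vllm.utils import FlexibleArgumentParser

parser = make_connect_arg_parser(FlexibleArgumentParser("vllm connect"))
args = parser.parse_args(
    ["--prefill-addr", "testipc0", "--decode-addr", "testipc1"])
print(args.port, args.protocol, args.prefill_addr, args.decode_addr)
# -> 7001 ipc testipc0 testipc1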

vllm/entrypoints/cli/disagg.py (new file, 115 lines)

@@ -0,0 +1,115 @@
# SPDX-License-Identifier: Apache-2.0

import argparse
import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.openai.cli_args import (LoRAParserAction,
                                              PromptAdapterParserAction)
from vllm.entrypoints.openai.zmq_server import run_zmq_server
from vllm.utils import FlexibleArgumentParser


class DisaggSubcommand(CLISubcommand):
    """The `disagg` subcommand for the vLLM CLI."""

    def __init__(self):
        self.name = "disagg"
        super().__init__()

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        # The default value of `--model`
        if not args.model_tag:
            raise ValueError(
                "With `vllm disagg`, you should provide the model as a "
                "positional argument instead of via the `--model` option.")

        # EngineArgs expects the model name to be passed as --model.
        args.model = args.model_tag

        asyncio.run(run_zmq_server(args))

    def validate(self, args: argparse.Namespace) -> None:
        validate_parsed_disagg_args(args)

    def subparser_init(
            self,
            subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
        disagg_parser = subparsers.add_parser(
            "disagg",
            help="Start the vLLM OpenAI Compatible API zmq server",
            usage="vllm disagg <model_tag> [options]")

        return make_disagg_arg_parser(disagg_parser)


def cmd_init() -> list[CLISubcommand]:
    return [DisaggSubcommand()]


def make_disagg_arg_parser(
        parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
    parser.add_argument(
        "model_tag",
        type=str,
        help=
        "The model tag to use for the vLLM OpenAI Compatible API zmq server.")
    parser.add_argument('--zmq-server-addr',
                        type=str,
                        required=True,
                        help='The address to serve the zmq server on.')
    parser.add_argument(
        "--disable-frontend-multiprocessing",
        action="store_true",
        help="If specified, will run the OpenAI frontend server in the same "
        "process as the model serving engine.")
    parser.add_argument(
        "--return-tokens-as-token-ids",
        action="store_true",
        help="When ``--max-logprobs`` is specified, represents single tokens "
        "as strings of the form 'token_id:{token_id}' so that tokens "
        "that are not JSON-encodable can be identified.")
    parser.add_argument('--max-log-len',
                        type=int,
                        default=None,
                        help='Max number of prompt characters or prompt '
                        'ID numbers being printed in log.'
                        '\n\nDefault: Unlimited')
    parser.add_argument(
        "--lora-modules",
        type=nullable_str,
        default=None,
        nargs='+',
        action=LoRAParserAction,
        help="LoRA module configurations in either 'name=path' format "
        "or JSON format. "
        "Example (old format): ``'name=path'`` "
        "Example (new format): "
        "``{\"name\": \"name\", \"path\": \"lora_path\", "
        "\"base_model_name\": \"id\"}``")

    parser.add_argument(
        "--prompt-adapters",
        type=nullable_str,
        default=None,
        nargs='+',
        action=PromptAdapterParserAction,
        help="Prompt adapter configurations in the format name=path. "
        "Multiple adapters can be specified.")

    AsyncEngineArgs.add_cli_args(parser)

    return parser


def validate_parsed_disagg_args(args: argparse.Namespace):
    """Quick checks for model disagg args that raise prior to loading."""
    if hasattr(args, "subparser") and args.subparser != "disagg":
        return

    # Enable reasoning needs a reasoning parser to be valid
    if args.enable_reasoning and not args.reasoning_parser:
        raise TypeError("Error: --enable-reasoning requires "
                        "--reasoning-parser")

vllm/entrypoints/cli/main.py:

@@ -6,6 +6,8 @@ import signal
import sys

import vllm.entrypoints.cli.benchmark.main
import vllm.entrypoints.cli.connect
import vllm.entrypoints.cli.disagg
import vllm.entrypoints.cli.openai
import vllm.entrypoints.cli.serve
import vllm.version

@@ -18,6 +20,8 @@ CMD_MODULES = [
    vllm.entrypoints.cli.openai,
    vllm.entrypoints.cli.serve,
    vllm.entrypoints.cli.benchmark.main,
    vllm.entrypoints.cli.disagg,
    vllm.entrypoints.cli.connect,
]

vllm/entrypoints/disagg_connector.py:

@@ -3,44 +3,69 @@
import asyncio
import json
import signal
import sys
import traceback
import uuid
from asyncio import Queue
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from typing import Union

import uvicorn
import uvloop
import zmq
import zmq.asyncio
from fastapi import FastAPI, Request
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse

from vllm.entrypoints.openai.zmq_server import (CONTENT_TYPE_ERROR,
                                                CONTENT_TYPE_JSON,
                                                CONTENT_TYPE_STREAM)
from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser

# default prefill and decode addr
time_out = 180
socket_prefill_num = 100
socket_decode_num = 100
context_type_json = "application/json"
context_type_error = "error"

# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.disagg_connector')

TIME_OUT = 5
X_REQUEST_ID_KEY = "X-Request-Id"

# communication between output handlers and execute_task_async
request_queues: dict[str, asyncio.Queue]


async def log_stats(request_queues: dict[str, asyncio.Queue]):
    while True:
        logger.info("Running requests: %d", len(request_queues))
        await asyncio.sleep(10)


# create an async socket using ZMQ_DEALER
async def create_socket(url: str,
                        zmqctx: zmq.asyncio.Context) -> zmq.asyncio.Socket:
    sock = zmqctx.socket(zmq.DEALER)
    identity = f"connector-{uuid.uuid4()}"
    sock.setsockopt(zmq.IDENTITY, identity.encode())
    sock.connect(url)
    logger.info("%s started at %s", identity, url)
    return sock
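create_socket relies on ZeroMQ's identity envelope: a DEALER with an explicit IDENTITY can be addressed by a ROUTER, which sees every message prefixed with that identity and must echo it back for the reply to route. A self-contained sketch of that mechanism (illustrative only; the addresses and identities are made up, not part of this commit):

import zmq

ctx = zmq.Context.instance()
router = ctx.socket(zmq.ROUTER)
router.bind("ipc://demo")

dealer = ctx.socket(zmq.DEALER)
dealer.setsockopt(zmq.IDENTITY, b"connector-demo")  # set before connect
dealer.connect("ipc://demo")

dealer.send_multipart([b"req-1", b"payload"])
# the ROUTER receives the dealer's identity as the first frame
identity, request_id, body = router.recv_multipart()
# echoing the identity routes the reply back to that dealer
router.send_multipart([identity, request_id, b"reply"])
print(dealer.recv_multipart())  # [b'req-1', b'reply']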


@asynccontextmanager
async def lifespan(app: FastAPI):
    # create socket pool with prefill and decode
    logger.info("start create_socket_pool")
    logger.info("start connect zmq server")
    app.state.zmqctx = zmq.asyncio.Context()
    app.state.sockets_prefill = await create_socket_pool(
        app.state.prefill_addr, socket_prefill_num, zmqctx=app.state.zmqctx)
    logger.info("success create_socket_pool sockets_prefill")
    app.state.sockets_decode = await create_socket_pool(
        app.state.decode_addr, socket_decode_num, zmqctx=app.state.zmqctx)
    logger.info("success create_socket_pool sockets_decode")
    app.state.prefill_socket = await create_socket(app.state.prefill_addr,
                                                   zmqctx=app.state.zmqctx)
    logger.info("success create_socket sockets_prefill")
    app.state.decode_socket = await create_socket(app.state.decode_addr,
                                                  zmqctx=app.state.zmqctx)
    logger.info("success create_socket sockets_decode")
    global request_queues
    request_queues = {}
    asyncio.create_task(prefill_handler(app.state.prefill_socket))
    asyncio.create_task(decode_handler(app.state.decode_socket))
    asyncio.create_task(log_stats(request_queues))
    yield
    ## close zmq context
    logger.info("shutdown disagg connector")

@@ -51,59 +76,140 @@ async def lifespan(app: FastAPI):

app = FastAPI(lifespan=lifespan)


# create an async socket pool of num_sockets sockets using ZMQ_DEALER
async def create_socket_pool(url: str, num_sockets: int,
                             zmqctx: zmq.asyncio.Context) -> Queue:
    sockets: Queue[zmq.Socket] = Queue()
    for i in range(num_sockets):
        sock = zmqctx.socket(zmq.DEALER)
        identity = f"worker-{i}-{uuid.uuid4()}"
        sock.setsockopt(zmq.IDENTITY, identity.encode())
        sock.connect(url)
        logger.info("%s started at %s with queue size %s", identity, url,
                    sockets.qsize())
        await sockets.put(sock)
    return sockets


@app.post('/v1/completions')
async def completions(request: Request, background_tasks: BackgroundTasks):
    try:
        # Add the X-Request-Id header to the raw headers list
        header = dict(request.headers)
        request_id = header.get(X_REQUEST_ID_KEY)
        queue = asyncio.Queue()
        if request_id is None:
            request_id = str(uuid.uuid4())
            logger.debug("add X-Request-Id: %s", request_id)
            header[X_REQUEST_ID_KEY] = request_id
        logger.debug("X-Request-Id is: %s", request_id)
        request_queues[request_id] = queue
        request_data = await request.json()
        logger.info("Received request_id: %s, request: %s, header: %s",
                    request_id, request_data, header)
        original_max_tokens = request_data['max_tokens']
        # change max_tokens to 1 so the prefill instance only does prefill
        request_data['max_tokens'] = 1
        # finish prefill
        try:
            prefill_response = await prefill(header, request_data)
            if isinstance(prefill_response, JSONResponse):
                return prefill_response
            logger.debug("finish prefill start decode")
            request_data['max_tokens'] = original_max_tokens
            response = await decode(header, request_data)
            logger.debug("finish decode")
        except Exception as e:
            logger.error("Error occurred in disagg prefill proxy server, %s",
                         e)
            response = JSONResponse({"error": {
                "message": str(e)
            }},
                                    status_code=500)
        return response

    except Exception as e:
        exc_info = sys.exc_info()
        logger.error("Error occurred in disagg prefill proxy server")
        logger.error(e)
        logger.error("".join(traceback.format_exception(*exc_info)))
        response = JSONResponse({"error": {
            "message": str(e)
        }},
                                status_code=500)
        return response
    finally:
        background_tasks.add_task(cleanup_request_id, request_id)


async def socket_recv_handler(socket: zmq.asyncio.Socket, scene: str):
    while True:
        try:
            [request_id, contentType, reply] = await socket.recv_multipart()
            contentType_str = contentType.decode()
            reply_str = reply.decode()
            request_id_str = request_id.decode()
            logger.debug(
                "%s socket received result contentType: %s, "
                "request_id: %s, reply: %s", scene, contentType_str,
                request_id_str, reply_str)
            if request_id_str in request_queues:
                request_queues[request_id_str].put_nowait(
                    (contentType_str, reply_str))
                if "[DONE]" in reply_str:
                    logger.debug(
                        "%s socket received stop signal request_id: %s", scene,
                        request_id_str)
                    request_queues[request_id_str].put_nowait(
                        (contentType_str, None))
            else:
                logger.debug(
                    "%s socket received but request_id not found discard: %s",
                    scene, request_id_str)
        except Exception as e:
            logger.error(traceback.format_exc())
            logger.error("%s handler error: %s", scene, e)


# prefill handler
async def prefill_handler(prefill_socket: zmq.asyncio.Socket):
    await socket_recv_handler(prefill_socket, "prefill")


# decode handler
async def decode_handler(decode_socket: zmq.asyncio.Socket):
    await socket_recv_handler(decode_socket, "decode")


# select a socket and execute the task
async def execute_task_async(route: str, headers: dict, request: dict,
                             sockets: Queue):
    sock: zmq.Socket = await sockets.get()
async def execute_task_async(headers: dict, request: dict,
                             socket: zmq.asyncio.Socket):
    try:
        request_id = headers.get(X_REQUEST_ID_KEY)
        requestBody = json.dumps(request)
        headersJson = json.dumps(headers)
        logger.info("Sending requestBody: %s to %s with headers: %s",
                    requestBody, route, headersJson)
        await asyncio.wait_for(sock.send_multipart(
            [route.encode(),
             headersJson.encode(),
             requestBody.encode()]),
                               timeout=time_out)
        logger.info("Sending requestBody: %s", requestBody)
        socket.send_multipart([request_id.encode(), requestBody.encode()])
        logger.debug("Sent end")
        queue = request_queues[request_id]
        while True:
            logger.debug("Waiting for reply")
            [contentType,
             reply] = await asyncio.wait_for(sock.recv_multipart(),
                                             timeout=time_out)
            contentType_str = contentType.decode()
            reply_str = reply.decode()
            logger.debug("Received result: %s, %s", contentType_str, reply_str)
            yield (contentType_str, reply_str)
            if context_type_json == contentType_str:
                logger.debug("Received %s message, return socket",
                             contentType_str)
            (contentType,
             reply) = await asyncio.wait_for(queue.get(), TIME_OUT)
            logger.debug("Received result: %s, %s", contentType, reply)
            if reply is None:
                logger.debug("Received stop signal, request_id: %s",
                             request_id)
                break
            if "[DONE]" in reply_str:
                logger.debug("Received stop signal, return socket")
            yield (contentType, reply)
            if contentType == CONTENT_TYPE_JSON:
                logger.debug("Received %s message, request_id: %s",
                             contentType, request_id)
                break

    except asyncio.TimeoutError:
        logger.error(traceback.format_exc())
        logger.error("Timeout, return socket: %s",
                     sock.getsockopt(zmq.IDENTITY))
        yield (context_type_error, "System Error")
        yield (CONTENT_TYPE_ERROR, "System Error")
    finally:
        await sockets.put(sock)
        logger.debug("request_id: %s, execute_task_async end", request_id)


async def prefill(header: dict,
                  original_request_data: dict) -> Union[JSONResponse, bool]:
    logger.debug("start prefill")
    generator = execute_task_async(header, original_request_data,
                                   app.state.prefill_socket)
    async for contentType, reply in generator:
        logger.debug("contentType: %s, reply: %s", contentType, reply)
        if contentType == CONTENT_TYPE_ERROR:
            response = JSONResponse({"error": reply})
            response.status_code = 500
            return response
    return True


async def generate_stream_response(first_reply: str,
@@ -113,78 +219,34 @@ async def generate_stream_response(first_reply: str,
    yield reply


async def prefill(route: str, header: dict, original_request_data: dict):
    logger.info("start prefill")
    generator = execute_task_async(route, header, original_request_data,
                                   app.state.sockets_prefill)
    async for contentType, reply in generator:
        logger.debug("contentType: %s, reply: %s", contentType, reply)
        if context_type_error == contentType:
            response = JSONResponse({"error": reply})
            response.status_code = 500
            return response
    return True


async def decode(route: str, header: dict, original_request_data: dict):
async def decode(
        header: dict,
        original_request_data: dict) -> Union[JSONResponse, StreamingResponse]:
    logger.info("start decode")
    generator = execute_task_async(route, header, original_request_data,
                                   app.state.sockets_decode)
    generator = execute_task_async(header, original_request_data,
                                   app.state.decode_socket)

    async for contentType, reply in generator:
        logger.debug("contentType: %s, reply: %s", contentType, reply)
        if context_type_error == contentType:
        if contentType == CONTENT_TYPE_ERROR:
            response = JSONResponse({"error": reply})
            response.status_code = 500
            return response
        elif context_type_json == contentType:
        elif contentType == CONTENT_TYPE_JSON:
            return JSONResponse(reply)
        else:
            return StreamingResponse(generate_stream_response(
                reply, generator),
                                     media_type="text/event-stream")
                                     media_type=CONTENT_TYPE_STREAM)


@app.post('/v1/completions')
async def chat_completions(request: Request):
    try:
        # Add the X-Request-Id header to the raw headers list
        x_request_id = str(uuid.uuid4())
        header = dict(request.headers)
        if header.get("X-Request-Id") is None:
            logger.info("add X-Request-Id: %s", x_request_id)
            header["X-Request-Id"] = x_request_id
        request_data = await request.json()
        logger.info("Received request: %s header: %s", request_data, header)
        original_max_tokens = request_data['max_tokens']
        # change max_tokens to 1 so the prefill instance only does prefill
        request_data['max_tokens'] = 1
        route = "/v1/completions"
        # finish prefill
        try:
            prefill_response = await prefill(route, header, request_data)
            if isinstance(prefill_response, JSONResponse):
                return prefill_response
            logger.info("finish prefill start decode")
            request_data['max_tokens'] = original_max_tokens
            response = await decode(route, header, request_data)
            logger.info("finish decode")
        except Exception as e:
            logger.error("Error occurred in disagg prefill proxy server, %s",
                         e)
            response = JSONResponse({"error": {"message": str(e)}})
        return response

    except Exception as e:
        import sys
        import traceback
        exc_info = sys.exc_info()
        logger.error("Error occurred in disagg prefill proxy server")
        logger.error(e)
        logger.error("".join(traceback.format_exception(*exc_info)))
def cleanup_request_id(request_id: str):
    if request_id in request_queues:
        logger.info("del request_id: %s, decode finished", request_id)
        del request_queues[request_id]


async def run_disagg_connector(args, **uvicorn_kwargs) -> None:
async def run_disagg_connector(args, **uvicorn_kwargs):
    logger.info("vLLM Disaggregate Connector start %s %s", args,
                uvicorn_kwargs)
    logger.info(args.prefill_addr)

@@ -192,8 +254,9 @@ async def run_disagg_connector(args, **uvicorn_kwargs) -> None:
    app.state.prefill_addr = f"ipc://{args.prefill_addr}"
    app.state.decode_addr = f"ipc://{args.decode_addr}"
    logger.info(
        "start connect prefill_addr: %s decode_addr: %s zmq server port: %s",
        app.state.prefill_addr, app.state.decode_addr, app.state.port)
        "start connect prefill_addr: %s decode_addr: %s "
        "zmq server fastapi port: %s", app.state.prefill_addr,
        app.state.decode_addr, app.state.port)

    def signal_handler(*_) -> None:
        # Interrupt server on sigterm while initializing

@@ -208,20 +271,22 @@ async def run_disagg_connector(args, **uvicorn_kwargs) -> None:

if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
    parser = FlexibleArgumentParser(description="vLLM disagg zmq server.")
    # This section should be in sync with vllm/entrypoints/cli/connect.py for
    # CLI entrypoints.
    parser = FlexibleArgumentParser(description="vLLM disagg connect server.")
    parser.add_argument("--port",
                        type=int,
                        default=8000,
                        help="The fastapi server port")
                        default=8001,
                        help="The fastapi server port (default: 8001)")
    # for security reasons, only ipc is supported now
    parser.add_argument("--prefill-addr",
                        type=str,
                        required=True,
                        help="The prefill address IP:PORT")
                        help="The zmq ipc prefill address")
    parser.add_argument("--decode-addr",
                        type=str,
                        required=True,
                        help="The decode address IP:PORT")
                        help="The zmq ipc decode address")

    args = parser.parse_args()

vllm/entrypoints/launcher.py:

@@ -7,16 +7,12 @@ from http import HTTPStatus
from typing import Any, Optional

import uvicorn
import zmq
import zmq.asyncio
import zmq.devices
from fastapi import FastAPI, Request, Response

from vllm import envs
from vllm.engine.async_llm_engine import AsyncEngineDeadError
from vllm.engine.multiprocessing import MQEngineDeadError
from vllm.entrypoints.ssl import SSLCertRefresher
from vllm.entrypoints.openai.connect_worker import worker_routine
from vllm.logger import init_logger
from vllm.utils import find_process_using_port

@@ -79,43 +75,6 @@ async def serve_http(app: FastAPI,
    return server.shutdown()


async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
    """Server routine"""
    logger.info("zmq Server start arg: %s, zmq_server_port: %d", arg,
                zmq_server_port)
    # different zmq contexts can't communicate over inproc
    workers_addr = "ipc://workers"
    clients_addr = f"ipc://127.0.0.1:{zmq_server_port}"
    # Prepare our context and sockets
    context = zmq.asyncio.Context.instance()
    try:
        tasks = [
            asyncio.create_task(worker_routine(workers_addr, app, context, i))
            for i in range(100)
        ]
        logger.info("zmq tasks: %s", tasks)
        # thread-safe proxy; creates its sockets in the background:
        # https://pyzmq.readthedocs.io/en/latest/api/zmq.devices.html#proxy-devices
        thread_proxy = zmq.devices.ThreadProxy(zmq.ROUTER, zmq.DEALER)
        # unlimited HWM
        hwm_limit = 0
        thread_proxy.bind_in(clients_addr)
        thread_proxy.setsockopt_in(zmq.SNDHWM, hwm_limit)
        thread_proxy.setsockopt_in(zmq.RCVHWM, hwm_limit)
        thread_proxy.bind_out(workers_addr)
        thread_proxy.setsockopt_out(zmq.SNDHWM, hwm_limit)
        thread_proxy.setsockopt_out(zmq.RCVHWM, hwm_limit)
        thread_proxy.start()
        await asyncio.gather(*tasks)
    except KeyboardInterrupt:
        print("ZMQ Server interrupted")
    except zmq.ZMQError as e:
        print("ZMQError:", e)
    finally:
        # We never get here but clean up anyhow
        context.destroy(linger=0)


def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
    """Adds handlers for fatal errors that should crash the server"""

vllm/entrypoints/openai/api_server.py:

@@ -36,7 +36,7 @@ from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.multiprocessing.engine import run_mp_engine
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.launcher import serve_http, serve_zmq
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                              validate_parsed_serve_args)

vllm/entrypoints/openai/connect_worker.py (deleted, 139 lines):

@@ -1,139 +0,0 @@
# SPDX-License-Identifier: Apache-2.0

import json
import tempfile
import traceback
import uuid
from typing import Optional

import httpx
import zmq
import zmq.asyncio
from fastapi import FastAPI, Request

# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                              CompletionResponse,
                                              ErrorResponse)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
from vllm.logger import init_logger

prometheus_multiproc_dir: tempfile.TemporaryDirectory

# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.openai.connect_worker')


def base(app: FastAPI) -> OpenAIServing:
    # Reuse the existing instance
    return tokenization(app)


def models(app: FastAPI) -> OpenAIServingModels:
    return app.state.openai_serving_models


def chat(app: FastAPI) -> Optional[OpenAIServingChat]:
    return app.state.openai_serving_chat


def completion(app: FastAPI) -> Optional[OpenAIServingCompletion]:
    return app.state.openai_serving_completion


def tokenization(app: FastAPI) -> OpenAIServingTokenization:
    return app.state.openai_serving_tokenization


def bytes_to_headers(bytes_data: bytes) -> httpx.Headers:
    headers_dict = json.loads(bytes_data.decode())
    return httpx.Headers(headers_dict)


async def worker_routine(worker_addr: str, app: FastAPI,
                         context: zmq.asyncio.Context, i: int = 0):
    """Worker routine"""
    try:
        # Socket to talk to dispatcher
        socket = context.socket(zmq.DEALER)
        worker_identity = f"worker-{i}-{uuid.uuid4()}"
        socket.setsockopt(zmq.IDENTITY, worker_identity.encode())
        socket.connect(worker_addr)
        logger.info("%s started at %s", worker_identity, worker_addr)
        while True:
            identity, url, header, body = await socket.recv_multipart()
            logger.info("worker-%d Received request identity: [ %s ]",
                        i, identity.decode())
            url_str = url.decode()
            logger.info("worker-%d Received request url: [ %s ]",
                        i, url_str)
            headers = bytes_to_headers(header)
            logger.info("worker-%d Received request headers: [ %s ]",
                        i, headers)
            body_json = json.loads(body.decode())
            logger.info("worker-%d Received request body: [ %s ]",
                        i, body_json)
            logger.info("worker-%d Calling OpenAI API", i)
            completionRequest = CompletionRequest(**body_json)
            createRequest = create_request(url_str, "POST", body_json, headers)
            generator = await create_completion(app, completionRequest,
                                                createRequest)
            context_type_json = b"application/json"
            if isinstance(generator, ErrorResponse):
                content = generator.model_dump_json()
                context_json = json.loads(content)
                context_json.append("status_code", generator.code)
                await socket.send_multipart([identity, context_type_json,
                                             json.dumps(context_json).encode('utf-8')])
            elif isinstance(generator, CompletionResponse):
                await socket.send_multipart([identity,
                                             context_type_json,
                                             json.dumps(generator.model_dump()).encode('utf-8')])
            else:
                async for chunk in generator:
                    await socket.send_multipart([identity,
                                                 b"text/event-stream",
                                                 chunk.encode('utf-8')])
    except Exception as e:
        logger.error("Error in worker routine: %s worker-%d", e, i)
        logger.error(traceback.format_exc())


async def create_completion(app: FastAPI, request: CompletionRequest,
                            raw_request: Request):
    handler = completion(app)
    logger.info("zmq request post: %s", request)
    if handler is None:
        return base(app).create_error_response(
            message="The model does not support Completions API")

    generator = await handler.create_completion(request, raw_request)
    logger.info("zmq request end post: %s", generator)
    return generator


def create_request(path: str, method: str, body: dict,
                   headers: httpx.Headers) -> Request:
    scope = {
        'type': 'http',
        'http_version': '1.1',
        'method': method,
        'path': path,
        'headers': list(headers.items()) if headers else [],
    }
    if body:
        scope['body'] = json.dumps(body)

    async def receive():
        return {
            'type': 'http.request',
            'body': scope.get('body', b''),
        }

    async def send(message):
        pass

    return Request(scope, receive=receive, send=send)


if __name__ == "__main__":
    print(bytes_to_headers(b'{"Content-Type": "application/json"}'))

vllm/entrypoints/openai/zmq_server.py (new file, 211 lines)

@@ -0,0 +1,211 @@
# SPDX-License-Identifier: Apache-2.0

import asyncio
import json
import os
import signal
import traceback
from argparse import Namespace

import zmq
import zmq.asyncio
from fastapi import Request

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                              CompletionResponse,
                                              ErrorResponse)
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    OpenAIServingModels)
from vllm.logger import init_logger
from vllm.utils import set_ulimit
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger('vllm.entrypoints.openai.zmq_server')

CONTENT_TYPE_JSON = "application/json"
CONTENT_TYPE_ERROR = "error"
CONTENT_TYPE_STREAM = "text/event-stream"

openai_serving_completion: OpenAIServingCompletion
openai_serving_models: OpenAIServingModels


async def log_stats(running_requests: set[asyncio.Task]):
    while True:
        logger.info("Running requests: %d", len(running_requests))
        await asyncio.sleep(10)


def _cleanup_ipc_path(server_addr: str):
    socket_path = server_addr.replace("ipc://", "")
    logger.info("cleaning up local IPC socket file %s", socket_path)
    if os.path.exists(socket_path):
        os.remove(socket_path)


async def serve_zmq(arg) -> None:
    """Server routine"""
    logger.info("zmq Server start arg: %s, zmq_server_addr: %s", arg,
                arg.zmq_server_addr)
    # different zmq contexts can't communicate over inproc, so use ipc
    server_addr = f"ipc://{arg.zmq_server_addr}"
    try:
        # Prepare our context and sockets
        context = zmq.asyncio.Context()
        socket = context.socket(zmq.ROUTER)
        # unlimited HWM
        hwm_limit = 0

        socket.bind(server_addr)
        socket.setsockopt(zmq.SNDHWM, hwm_limit)
        socket.setsockopt(zmq.RCVHWM, hwm_limit)

        running_requests: set[asyncio.Task] = set()
        logger.info("zmq Server started at %s", server_addr)
        asyncio.create_task(log_stats(running_requests))

        while True:
            try:
                logger.debug("zmq Server waiting for request")
                # get a new request from the client
                identity, request_id, body = await socket.recv_multipart()
                # launch a request-handler coroutine
                task = asyncio.create_task(
                    worker_routine(identity, request_id, body, socket))
                running_requests.add(task)
                task.add_done_callback(running_requests.discard)
            except zmq.ZMQError as e:
                logger.error("ZMQError: %s", e)
                break
            except Exception as e:
                logger.error("Unexpected error: %s", e)
                break
    except KeyboardInterrupt:
        logger.info("KeyboardInterrupt received, exiting")
    finally:
        # Clean up resources
        for task in running_requests:
            task.cancel()
        await asyncio.gather(*running_requests, return_exceptions=True)
        socket.close()
        context.destroy(linger=0)
        _cleanup_ipc_path(server_addr)
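Taken together with the connector code above, the wire format can be reconstructed from this commit's send_multipart/recv_multipart calls. A summary (an inference from the code shown here, not a documented protocol):

# connector (DEALER) -> engine (ROUTER); the ROUTER prepends the identity frame:
#     sent:     [request_id, request_body_json]
#     received: [identity, request_id, request_body_json]
# engine (ROUTER) -> connector (DEALER); the ROUTER consumes the identity frame:
#     sent:     [identity, request_id, content_type, payload]
#     received: [request_id, content_type, payload]
# content_type is one of "application/json", "text/event-stream", or "error";
# streamed replies end with a chunk containing "[DONE]".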


async def run_zmq_server(args) -> None:
    logger.info("vLLM zmq server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    # workaround to avoid footguns where uvicorn drops requests with too
    # many concurrent requests active
    set_ulimit()

    def signal_handler(*_) -> None:
        # Interrupt server on sigterm while initializing
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, signal_handler)

    async with build_async_engine_client(args) as engine_client:

        model_config = await engine_client.get_model_config()
        await init_state(engine_client, model_config, args)
        logger.info("init_state successful")
        await serve_zmq(args)


async def init_state(
    engine_client: EngineClient,
    model_config: ModelConfig,
    args: Namespace,
) -> None:
    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    base_model_paths = [
        BaseModelPath(name=name, model_path=args.model)
        for name in served_model_names
    ]

    global openai_serving_models
    openai_serving_models = OpenAIServingModels(
        engine_client=engine_client,
        model_config=model_config,
        base_model_paths=base_model_paths,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
    )
    await openai_serving_models.init_static_loras()

    global openai_serving_completion
    openai_serving_completion = OpenAIServingCompletion(
        engine_client,
        model_config,
        openai_serving_models,
        request_logger=None,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )


async def worker_routine(identity: bytes, request_id: bytes, body: bytes,
                         socket: zmq.asyncio.Socket):
    """Worker routine"""
    try:
        body_json = json.loads(body.decode())
        request_id_str = request_id.decode()
        logger.debug("receive request: %s from %s, request_id: %s", body_json,
                     identity.decode(), request_id_str)

        completionRequest = CompletionRequest(**body_json)
        generator = await create_completion(completionRequest, None)
        content_type_json = CONTENT_TYPE_JSON.encode('utf-8')
        content_type_stream = CONTENT_TYPE_STREAM.encode('utf-8')
        if isinstance(generator, ErrorResponse):
            content = json.loads(generator.model_dump_json())
            content.update({"status_code": generator.code})
            logger.debug("send ErrorResponse %s", json.dumps(content))
            await socket.send_multipart([
                identity, request_id, content_type_json,
                json.dumps(content).encode('utf-8')
            ])
        elif isinstance(generator, CompletionResponse):
            logger.debug("send CompletionResponse %s",
                         json.dumps(generator.model_dump()))
            await socket.send_multipart([
                identity, request_id, content_type_json,
                json.dumps(generator.model_dump()).encode('utf-8')
            ])
        else:
            async for chunk in generator:
                logger.debug(
                    "send chunk identity: %s, request_id: %s, chunk: %s",
                    identity.decode(), request_id.decode(), chunk)
                await socket.send_multipart([
                    identity, request_id, content_type_stream,
                    chunk.encode('utf-8')
                ])
    except Exception as e:
        # decode request_id directly: request_id_str may be unbound if the
        # failure happened before it was assigned
        logger.error("Error in worker routine: %s request_id: %s", e,
                     request_id.decode())
        logger.error(traceback.format_exc())
        content_type_stream = CONTENT_TYPE_STREAM.encode('utf-8')
        logger.debug("send ErrorResponse %s", str(e))
        await socket.send_multipart([
            identity, request_id,
            CONTENT_TYPE_ERROR.encode('utf-8'),
            str(e).encode('utf-8')
        ])


async def create_completion(request: CompletionRequest, raw_request: Request):
    logger.debug("zmq request post: %s", request)
    generator = await openai_serving_completion.create_completion(
        request, raw_request)
    logger.debug("zmq request end post")
    return generator