Mirror of https://git.datalinker.icu/vllm-project/vllm.git

commit 85687b43e7 (parent 120bbdfd82)
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
@@ -1,72 +0,0 @@
# SPDX-License-Identifier: Apache-2.0

import argparse

import uvloop

from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.disagg_connector import run_disagg_connector
from vllm.utils import FlexibleArgumentParser


class ConnectSubcommand(CLISubcommand):
    """The `connect` subcommand for the vLLM CLI."""

    def __init__(self):
        self.name = "connect"
        super().__init__()

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        uvloop.run(run_disagg_connector(args))

    def validate(self, args: argparse.Namespace) -> None:
        validate_connect_parsed_args(args)

    def subparser_init(
            self,
            subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
        connect_parser = subparsers.add_parser(
            "connect",
            help="Start the vLLM OpenAI Compatible API Server, which "
            "connects to other servers to disaggregate prefill and decode",
            usage="vllm connect [options]")

        return make_connect_arg_parser(connect_parser)


def cmd_init() -> list[CLISubcommand]:
    return [ConnectSubcommand()]


def make_connect_arg_parser(
        parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
    parser.add_argument("--port",
                        type=int,
                        default=7001,
                        help="The FastAPI server port (default: 7001)")
    # Only ipc is supported for now; tcp (with auth) may be added later.
    parser.add_argument(
        "--protocol",
        type=str,
        choices=["ipc"],
        default="ipc",
        help="The zmq socket protocol; only IPC "
        "(Inter-Process Communication) is supported")
    # Security concern: only ipc is supported right now.
    parser.add_argument("--prefill-addr",
                        type=str,
                        required=True,
                        help="The zmq ipc prefill address")
    parser.add_argument("--decode-addr",
                        type=str,
                        required=True,
                        help="The zmq ipc decode address")

    return parser


def validate_connect_parsed_args(args: argparse.Namespace):
    """Quick checks for connect args that raise prior to loading."""
    if hasattr(args, "subparser") and args.subparser != "connect":
        return
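A minimal sketch (not part of the commit) of how the deleted subcommand above would be exercised; the registration harness (`add_subparsers`, `set_defaults`) follows vLLM's CLI pattern, and the socket paths are illustrative:

from vllm.utils import FlexibleArgumentParser

parser = FlexibleArgumentParser(description="vLLM CLI")
subparsers = parser.add_subparsers(dest="subparser", required=True)

sub = ConnectSubcommand()
# subparser_init registers "connect" and returns its configured parser.
sub.subparser_init(subparsers).set_defaults(dispatch_function=sub.cmd)

args = parser.parse_args([
    "connect",
    "--prefill-addr", "/tmp/prefill.ipc",   # illustrative ipc paths
    "--decode-addr", "/tmp/decode.ipc",
])
sub.validate(args)            # cheap checks before anything heavy loads
args.dispatch_function(args)  # runs uvloop.run(run_disagg_connector(args))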
@@ -1,115 +0,0 @@
# SPDX-License-Identifier: Apache-2.0

import argparse
import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.openai.cli_args import (LoRAParserAction,
                                              PromptAdapterParserAction)
from vllm.entrypoints.openai.zmq_server import run_zmq_server
from vllm.utils import FlexibleArgumentParser


class DisaggSubcommand(CLISubcommand):
    """The `disagg` subcommand for the vLLM CLI."""

    def __init__(self):
        self.name = "disagg"
        super().__init__()

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        # The default value of `--model`
        if not args.model_tag:
            raise ValueError(
                "With `vllm disagg`, you should provide the model as a "
                "positional argument instead of via the `--model` option.")

        # EngineArgs expects the model name to be passed as --model.
        args.model = args.model_tag

        asyncio.run(run_zmq_server(args))

    def validate(self, args: argparse.Namespace) -> None:
        validate_parsed_disagg_args(args)

    def subparser_init(
            self,
            subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
        disagg_parser = subparsers.add_parser(
            "disagg",
            help="Start the vLLM OpenAI Compatible API zmq server",
            usage="vllm disagg <model_tag> [options]")

        return make_disagg_arg_parser(disagg_parser)


def cmd_init() -> list[CLISubcommand]:
    return [DisaggSubcommand()]


def make_disagg_arg_parser(
        parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
    parser.add_argument(
        "model_tag",
        type=str,
        help="The model tag to use for the vLLM OpenAI Compatible API "
        "zmq server.")
    parser.add_argument('--zmq-server-addr',
                        type=str,
                        required=True,
                        help='The address to serve the zmq server on.')
    parser.add_argument(
        "--disable-frontend-multiprocessing",
        action="store_true",
        help="If specified, will run the OpenAI frontend server in the same "
        "process as the model serving engine.")
    parser.add_argument(
        "--return-tokens-as-token-ids",
        action="store_true",
        help="When ``--max-logprobs`` is specified, represents single tokens "
        "as strings of the form 'token_id:{token_id}' so that tokens "
        "that are not JSON-encodable can be identified.")
    parser.add_argument('--max-log-len',
                        type=int,
                        default=None,
                        help='Max number of prompt characters or prompt '
                        'ID numbers being printed in log.'
                        '\n\nDefault: Unlimited')
    parser.add_argument(
        "--lora-modules",
        type=nullable_str,
        default=None,
        nargs='+',
        action=LoRAParserAction,
        help="LoRA module configurations in either 'name=path' format "
        "or JSON format. "
        "Example (old format): ``'name=path'`` "
        "Example (new format): "
        "``{\"name\": \"name\", \"path\": \"lora_path\", "
        "\"base_model_name\": \"id\"}``")

    parser.add_argument(
        "--prompt-adapters",
        type=nullable_str,
        default=None,
        nargs='+',
        action=PromptAdapterParserAction,
        help="Prompt adapter configurations in the format name=path. "
        "Multiple adapters can be specified.")

    AsyncEngineArgs.add_cli_args(parser)

    return parser


def validate_parsed_disagg_args(args: argparse.Namespace):
    """Quick checks for model disagg args that raise prior to loading."""
    if hasattr(args, "subparser") and args.subparser != "disagg":
        return

    # Enable reasoning needs a reasoning parser to be valid
    if args.enable_reasoning and not args.reasoning_parser:
        raise TypeError("Error: --enable-reasoning requires "
                        "--reasoning-parser")
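Because make_disagg_arg_parser ends by calling AsyncEngineArgs.add_cli_args, the disagg parser accepts every engine flag alongside its own. A hedged sketch, with an illustrative model tag and ipc path:

from vllm.utils import FlexibleArgumentParser

parser = make_disagg_arg_parser(FlexibleArgumentParser())
args = parser.parse_args([
    "facebook/opt-125m",                     # positional model_tag
    "--zmq-server-addr", "/tmp/worker.ipc",  # illustrative ipc path
    "--max-model-len", "2048",               # comes from AsyncEngineArgs
])
assert args.model_tag == "facebook/opt-125m"
# DisaggSubcommand.cmd() then copies it to the field EngineArgs expects:
args.model = args.model_tag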
@@ -6,8 +6,6 @@ import signal
import sys

import vllm.entrypoints.cli.benchmark.main
import vllm.entrypoints.cli.connect
import vllm.entrypoints.cli.disagg
import vllm.entrypoints.cli.openai
import vllm.entrypoints.cli.serve
import vllm.version
@@ -20,8 +18,6 @@ CMD_MODULES = [
    vllm.entrypoints.cli.openai,
    vllm.entrypoints.cli.serve,
    vllm.entrypoints.cli.benchmark.main,
    vllm.entrypoints.cli.disagg,
    vllm.entrypoints.cli.connect,
]
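For context, the CLI entrypoint builds one subparser tree from CMD_MODULES: each module exposes cmd_init() returning CLISubcommand instances, and each subcommand installs its own dispatch function. A condensed sketch of that pattern (the two-module list is illustrative):

import vllm.entrypoints.cli.openai
import vllm.entrypoints.cli.serve
from vllm.utils import FlexibleArgumentParser

CMD_MODULES = [vllm.entrypoints.cli.openai, vllm.entrypoints.cli.serve]

parser = FlexibleArgumentParser(description="vLLM CLI")
subparsers = parser.add_subparsers(required=False, dest="subparser")
cmds = {}
for cmd_module in CMD_MODULES:
    for cmd in cmd_module.cmd_init():
        # Each subcommand wires its own parser and dispatch function.
        cmd.subparser_init(subparsers).set_defaults(
            dispatch_function=cmd.cmd)
        cmds[cmd.name] = cmd

args = parser.parse_args()
if args.subparser in cmds:
    cmds[args.subparser].validate(args)
if hasattr(args, "dispatch_function"):
    args.dispatch_function(args)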
@@ -14,7 +14,7 @@ from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                     OpenAIServingModels)
from vllm.entrypoints.openai.protocol import CompletionRequest
from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser
from vllm.utils import FlexibleArgumentParser, set_ulimit, make_zmq_socket
from vllm.entrypoints.openai.protocol import (
    CompletionResponse, ErrorResponse)
@@ -55,16 +55,22 @@ async def main(args, **uvicorn_kwargs):
    logger.info("vLLM Disaggregate Connector Start %s %s", args,
                uvicorn_kwargs)

    # Avoid dropping requests under high concurrency.
    set_ulimit()

    # IPC Paths.
    # NOTE FOR DEVELOPERS: when shifting to TCP, ensure you
    # are not using pickle, to avoid the RCE security flaw.
    prefill_addr = f"ipc://{args.prefill_addr}"
    decode_addr = f"ipc://{args.decode_addr}"
    connector_addr = f"ipc://{args.connector_addr}"

    # Start Engine.
    with pd_engine_client_ctx_manager(
            args.model, prefill_addr, decode_addr,
            connector_addr) as engine_client:

        # Initialize App State.
        model_config = await engine_client.get_model_config()

        # Models.
        app.state.openai_serving_models = OpenAIServingModels(
            engine_client=engine_client,
            model_config=model_config,
@@ -73,8 +79,6 @@ async def main(args, **uvicorn_kwargs):
                    model_path=args.model)
            ],
        )

        # Completions.
        app.state.openai_serving_completion = OpenAIServingCompletion(
            engine_client=engine_client,
            model_config=model_config,
@@ -82,10 +86,10 @@ async def main(args, **uvicorn_kwargs):
            request_logger=None,
        )

        # Init Uvicorn Server.
        config = uvicorn.Config(app, host="0.0.0.0", port=args.port)
        server = uvicorn.Server(config)
        await server.serve()
        # Run Server.
        config = uvicorn.Config(app, host="0.0.0.0", port=args.port)
        server = uvicorn.Server(config)
        await server.serve()


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
@@ -2,6 +2,7 @@
import asyncio
import msgspec
import os
from collections.abc import AsyncGenerator
from typing import Dict, List, Mapping, Optional

@@ -12,9 +13,7 @@ import zmq.asyncio
from vllm import SamplingParams
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                              validate_parsed_serve_args)
from vllm.entrypoints.disaggregated.types import PDRequest, PDResponse
from vllm.inputs.data import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
@@ -26,38 +25,12 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.utils import Device, FlexibleArgumentParser, make_zmq_socket
from vllm.utils import Device, make_zmq_socket

logger = init_logger(__name__)

DEFAULT_MAX_TOKENS = 32000

# NOTE FOR DEVELOPERS:
# DO NOT USE PICKLE FOR THESE CLASSES. IN A MULTI-NODE
# SETUP WE WILL USE TCP. WE CANNOT USE PICKLE, OTHERWISE
# WE RISK REMOTE CODE EXECUTION FROM UNTRUSTED USERS.

class PDRequest(msgspec.Struct,
                array_like=True,  # type: ignore[call-arg]
                omit_defaults=True,  # type: ignore[call-arg]
                gc=False):  # type: ignore[call-arg]
    request_id: str
    prompt_token_ids: List[int]
    sampling_params: SamplingParams
    # TODO: support multimodal inputs.


class PDResponse(msgspec.Struct,
                 array_like=True,  # type: ignore[call-arg]
                 omit_defaults=True,  # type: ignore[call-arg]
                 gc=False):  # type: ignore[call-arg]
    request_id: str
    success: bool
    delta_text: Optional[str] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[str] = None
    logprobs = None  # TODO
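These structs exist so requests can cross process boundaries as msgpack, never pickle (see the NOTE above). A hedged round-trip sketch using msgspec's msgpack codec; it assumes SamplingParams is itself msgspec-serializable, as it is in recent vLLM, and the field values are illustrative:

from msgspec import msgpack

from vllm import SamplingParams

encoder = msgpack.Encoder()
decoder = msgpack.Decoder(PDRequest)

req = PDRequest(request_id="req-0",
                prompt_token_ids=[1, 2, 3],
                sampling_params=SamplingParams(max_tokens=16))
wire = encoder.encode(req)  # compact bytes, no code objects on the wire
assert decoder.decode(wire).request_id == "req-0"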
class PDEngine:
    """
    PDEngine:
@@ -89,6 +62,8 @@ class PDEngine:
        self.to_prefill = make_zmq_socket(
            self.ctx, f"{prefill_addr}", zmq.constants.PUSH)
        self.connector_addr = connector_addr
        self.decode_addr = decode_addr
        self.prefill_addr = prefill_addr

        # Background loops (started on first generate()).
        self.output_handler: Optional[asyncio.Task] = None
@@ -118,7 +93,6 @@ class PDEngine:
            revision=self.model_config.tokenizer_revision,
            truncation_side=self.model_config.truncation_side)
        self.tokenizer = TokenizerGroup(**init_kwargs)

    def shutdown(self):
        if (ctx := self.ctx) is not None:
@@ -128,6 +102,14 @@ class PDEngine:
        if (task := self.output_handler) is not None:
            task.cancel()

        ipc_paths = [
            self.connector_addr, self.decode_addr, self.prefill_addr
        ]
        for path in ipc_paths:
            socket_path = path.replace("ipc://", "")
            if os.path.exists(socket_path):
                os.remove(socket_path)

    async def _run_log_running(self):
        logger.info("Running requests: %d", len(self.queues))
        await asyncio.sleep(10.)
@@ -312,25 +294,3 @@ class PDEngine:

    async def add_lora(self, lora_request: LoRARequest) -> None:
        raise NotImplementedError


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="vLLM OpenAI-Compatible RESTful API server.")
    parser.add_argument("--connector-addr",
                        type=str,
                        required=True,
                        help="The zmq ipc connector address")
    parser.add_argument("--prefill-addr",
                        type=str,
                        required=True,
                        help="The zmq ipc prefill address")
    parser.add_argument("--decode-addr",
                        type=str,
                        required=True,
                        help="The zmq ipc decode address")
    parser = make_arg_parser(parser)
    args = parser.parse_args()
    validate_parsed_serve_args(args)

    uvloop.run(run_server(args))
@@ -1,225 +1,126 @@
# SPDX-License-Identifier: Apache-2.0

import asyncio
import json
import os
from msgspec import msgpack  # provides msgpack.Encoder / msgpack.Decoder
import signal
import traceback
from argparse import Namespace
from http import HTTPStatus
import uvloop
from typing import Optional

import zmq
import zmq.asyncio

from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.disaggregated.types import PDRequest, PDResponse
from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                              CompletionResponse,
                                              ErrorResponse, ZmqMsgRequest,
                                              ZmqMsgResponse)
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    OpenAIServingModels)
from vllm.logger import init_logger
from vllm.utils import set_ulimit
from vllm.utils import FlexibleArgumentParser, set_ulimit, make_zmq_socket
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger('vllm.entrypoints.openai.zmq_server')
logger = init_logger(__name__)

openai_serving_completion: OpenAIServingCompletion
openai_serving_models: OpenAIServingModels


async def log_stats(running_requests: set[asyncio.Task]):
    while True:
        logger.info("Running requests: %d", len(running_requests))
        await asyncio.sleep(10)


def _cleanup_ipc_path(server_addr: str):
    socket_path = server_addr.replace("ipc://", "")
    logger.info("cleaning up local IPC socket file %s", socket_path)
    if os.path.exists(socket_path):
        os.remove(socket_path)


async def serve_zmq(arg) -> None:
    """Server routine"""
    logger.info("zmq Server start arg: %s, zmq_server_addr: %s", arg,
                arg.zmq_server_addr)
    # Different zmq contexts can't communicate over inproc.
    server_addr = f"ipc://{arg.zmq_server_addr}"
async def handle_request(
    request: PDRequest,
    engine: EngineClient,
    socket: zmq.asyncio.Socket,
    encoder: msgpack.Encoder,
) -> None:
    request_id = request.request_id
    try:
        # Prepare our context and sockets.
        context = zmq.asyncio.Context()
        socket = context.socket(zmq.ROUTER)
        # Unlimited HWM.
        hwm_limit = 0
        # 1) Generate RequestOutputs.
        async for request_output in engine.generate(
                prompt=request.prompt_token_ids,
                sampling_params=request.sampling_params,
                request_id=request_id):

            socket.bind(server_addr)
            socket.setsockopt(zmq.SNDHWM, hwm_limit)
            socket.setsockopt(zmq.RCVHWM, hwm_limit)
            assert len(request_output.outputs) == 1, (
                "Only support N=1 right now.")
            out = request_output.outputs[0]

            running_requests: set[asyncio.Task] = set()
            logger.info("zmq Server started at %s", server_addr)
            asyncio.create_task(log_stats(running_requests))
            # 2) Convert RequestOutput --> PDResponse.
            response = PDResponse(
                request_id=request_id,
                success=True,
                text=out.text,
                token_ids=out.token_ids,
                finish_reasons=out.finish_reason,
                stop_reason=out.stop_reason,
            )
            response_bytes = encoder.encode(response)

            # 3) Send to Connector.
            await socket.send(response_bytes, copy=False)

    except Exception as e:
        # TODO: actual error handling.
        logger.error("Exception in Worker Routine: %s request_id: %s", e,
                     request_id)
        response = PDResponse(request_id=request_id, success=False)
        response_bytes = encoder.encode(response)
        await socket.send(response_bytes, copy=False)


async def run_server(args, engine: EngineClient):
    """Get Requests and Handle Them."""
    running_requests: set[asyncio.Task] = set()
    decoder = msgpack.Decoder(PDRequest)
    encoder = msgpack.Encoder()

    ctx: Optional[zmq.asyncio.Context] = None
    try:
        # IPC Setup.
        ctx = zmq.asyncio.Context()
        from_connector = make_zmq_socket(
            ctx, f"ipc://{args.server_addr}", zmq.constants.PULL)
        to_connector = make_zmq_socket(
            ctx, f"ipc://{args.connector_addr}", zmq.constants.PUSH)

        # Main Loop.
        while True:
            try:
                logger.debug("zmq Server waiting for request")
                # Get a new request from the client.
                message_parts = await socket.recv_multipart()
                logger.debug("received request: %s", message_parts)
                logger.debug("received len: %d", len(message_parts))
                identity, body = message_parts[0], message_parts[1]
                zmq_msg_request = ZmqMsgRequest.model_validate_json(body)
                # Launch a request handler coroutine.
                task = asyncio.create_task(
                    worker_routine(identity, zmq_msg_request, socket))
                running_requests.add(task)
                task.add_done_callback(running_requests.discard)
            except zmq.ZMQError as e:
                logger.error(traceback.format_exc())
                logger.error("ZMQError: %s", e)
                break
            except Exception as e:
                logger.error(traceback.format_exc())
                logger.error("Unexpected error: %s", e)
                break
            # 1) Get request from the Connector.
            pd_request_bytes = await from_connector.recv()
            pd_request = decoder.decode(pd_request_bytes)

            # 2) Launch a coroutine to handle the request.
            task = asyncio.create_task(handle_request(
                pd_request, engine, to_connector, encoder))
            running_requests.add(task)
            task.add_done_callback(running_requests.discard)

    except KeyboardInterrupt:
        logger.info("KeyboardInterrupt received, exiting")
        logger.debug("Worker server loop interrupted.")

    finally:
        # Clean up resources.
        for task in running_requests:
            task.cancel()
        await asyncio.gather(*running_requests, return_exceptions=True)
        socket.close()
        context.destroy(linger=0)
        _cleanup_ipc_path(server_addr)
        if ctx is not None:
            ctx.destroy(linger=0)


async def run_zmq_server(args) -> None:
    logger.info("vLLM zmq server version %s", VLLM_VERSION)
async def main(args) -> None:
    logger.info("vLLM P/D Worker Server %s", VLLM_VERSION)
    logger.info("args: %s", args)

    # workaround to avoid footguns where uvicorn drops requests with too
    # many concurrent requests active
    # Workaround to avoid footguns where uvicorn drops requests
    # with too many concurrent requests active due to ulimit.
    set_ulimit()

    # Interrupt on sigterm during initialization.
    def signal_handler(*_) -> None:
        # Interrupt server on sigterm while initializing.
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, signal_handler)

    async with build_async_engine_client(args) as engine_client:
    async with build_async_engine_client(args) as engine:
        await run_server(args, engine)

        model_config = await engine_client.get_model_config()
        await init_state(engine_client, model_config, args)
        logger.info("init_state successful")
        await serve_zmq(args)


async def init_state(
    engine_client: EngineClient,
    model_config: ModelConfig,
    args: Namespace,
) -> None:
    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    base_model_paths = [
        BaseModelPath(name=name, model_path=args.model)
        for name in served_model_names
    ]

    global openai_serving_models
    openai_serving_models = OpenAIServingModels(
        engine_client=engine_client,
        model_config=model_config,
        base_model_paths=base_model_paths,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
    )
    await openai_serving_models.init_static_loras()

    global openai_serving_completion
    openai_serving_completion = OpenAIServingCompletion(
        engine_client,
        model_config,
        openai_serving_models,
        request_logger=None,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )


async def worker_routine(identity: bytes, zmq_msg_request: ZmqMsgRequest,
                         socket: zmq.asyncio.Socket):
    """Worker routine"""
    request_id = zmq_msg_request.request_id
    try:
        logger.debug("receive request: %s from %s, request_id: %s",
                     zmq_msg_request.model_dump_json(), identity.decode(),
                     request_id)
        if isinstance(zmq_msg_request.body, CompletionRequest):
            await create_completion(identity, zmq_msg_request, socket)
        else:
            logger.error("Error in worker routine: %s request_id: %s",
                         "unsupported request type", request_id)
            raise Exception("unsupported request type")

    except Exception as e:
        logger.error("Error in worker routine: %s request_id: %s", e,
                     request_id)
        logger.error(traceback.format_exc())
        logger.debug("send ErrorResponse %s", str(e))
        await socket.send_multipart([
            identity,
            ZmqMsgResponse(request_id=request_id,
                           type=zmq_msg_request.type,
                           body=json.dumps({
                               "content": str(e),
                               "status_code":
                               HTTPStatus.INTERNAL_SERVER_ERROR
                           })).model_dump_json().encode(),
        ])


async def create_completion(identity: bytes, zmq_msg_request: ZmqMsgRequest,
                            socket: zmq.asyncio.Socket):
    request: CompletionRequest = zmq_msg_request.body
    logger.debug("zmq request post: %s", request.model_dump_json())
    generator = await openai_serving_completion.create_completion(request)
    logger.debug("zmq request end post")
    request_id = zmq_msg_request.request_id
    if isinstance(generator, (ErrorResponse, CompletionResponse)):
        logger.debug("send response %s", generator.model_dump_json())
        if isinstance(generator, ErrorResponse):
            body = json.dumps({
                "content": generator.model_dump(),
                "status_code": generator.code
            })
        elif isinstance(generator, CompletionResponse):
            body = json.dumps({"content": generator.model_dump()})
        zmq_msg_response = ZmqMsgResponse(request_id=request_id,
                                          type=zmq_msg_request.type,
                                          body_type="response",
                                          body=body)
        await socket.send_multipart(
            [identity, zmq_msg_response.model_dump_json().encode()])
    else:
        async for chunk in generator:
            zmq_msg_response = ZmqMsgResponse(request_id=request_id,
                                              type=zmq_msg_request.type,
                                              body=chunk)
            if "data: [DONE]" not in chunk:
                zmq_msg_response.stop = False
            logger.debug("send chunk identity: %s, request_id: %s, chunk: %s",
                         identity.decode(), request_id, chunk)
            await socket.send_multipart(
                [identity,
                 zmq_msg_response.model_dump_json().encode()])


if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    parser.add_argument('--connector-addr',
                        type=str,
                        required=True,
                        help='The address of the connector.')
    parser.add_argument('--worker-addr',
                        type=str,
                        required=True,
                        help='The address of the worker.')
    AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    uvloop.run(main(args))
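Client side of the ROUTER loop above, as a hedged sketch: a DEALER socket sends one JSON-encoded ZmqMsgRequest frame (the ROUTER receives it as identity + body), then reads streamed ZmqMsgResponse frames back; the stop-flag semantics are assumed from the chunk loop in create_completion:

import zmq
import zmq.asyncio


async def request_completions(zmq_server_addr: str,
                              zmq_msg_request: "ZmqMsgRequest") -> None:
    ctx = zmq.asyncio.Context()
    sock = ctx.socket(zmq.DEALER)
    sock.connect(f"ipc://{zmq_server_addr}")
    try:
        # One frame from a DEALER arrives at the ROUTER as (identity, body).
        await sock.send(zmq_msg_request.model_dump_json().encode())
        while True:
            body, = await sock.recv_multipart()
            response = ZmqMsgResponse.model_validate_json(body)
            print(response.body)
            if response.stop:  # assumed: stop stays True on the final chunk
                break
    finally:
        sock.close()
        ctx.destroy(linger=0)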