Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
This commit is contained in:
Robert Shaw 2025-03-22 17:00:46 -04:00
parent 120bbdfd82
commit 85687b43e7
6 changed files with 118 additions and 444 deletions

View File

@ -1,72 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import uvloop
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.disagg_connector import run_disagg_connector
from vllm.utils import FlexibleArgumentParser
class ConnectSubcommand(CLISubcommand):
    """The `connect` subcommand for the vLLM CLI. """

    def __init__(self):
        self.name = "connect"
        super().__init__()

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        # Run the async connector entry point under uvloop.
        uvloop.run(run_disagg_connector(args))

    def validate(self, args: argparse.Namespace) -> None:
        validate_connect_parsed_args(args)

    def subparser_init(
            self,
            subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
        connect_parser = subparsers.add_parser(
            "connect",
            # BUG FIX: the original implicit string concatenation rendered as
            # "...connect to otherservers disaggreate..." (missing space and
            # a typo). Spell out the intended help text.
            help="Start the vLLM OpenAI Compatible API Server which connects "
            "to other servers to disaggregate prefill and decode",
            usage="vllm connect [options]")
        return make_connect_arg_parser(connect_parser)
def cmd_init() -> list[CLISubcommand]:
    """Expose the `connect` subcommand to the CLI registry."""
    subcommands: list[CLISubcommand] = [ConnectSubcommand()]
    return subcommands
def make_connect_arg_parser(
        parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
    """Register the CLI flags used by `vllm connect` on *parser*."""
    parser.add_argument(
        "--port",
        type=int,
        default=7001,
        help="The fastapi server port default 7001")
    # support ipc only now, support tcp later(with auth)
    parser.add_argument(
        "--protocol",
        type=str,
        choices=["ipc"],
        default="ipc",
        help="The zmq socket addr protocol IPC (Inter-Process Communication)")
    # security concern only support ipc now
    parser.add_argument(
        "--prefill-addr",
        type=str,
        required=True,
        help="The zmq ipc prefill address")
    parser.add_argument(
        "--decode-addr",
        type=str,
        required=True,
        help="The zmq ipc decode address")
    return parser
def validate_connect_parsed_args(args: argparse.Namespace):
    """Quick checks for connect args that raise prior to loading."""
    # Only applies when the parsed command actually was `connect`;
    # bail out early for any other subcommand.
    subcommand = getattr(args, "subparser", None)
    if subcommand is not None and subcommand != "connect":
        return

View File

@ -1,115 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import asyncio
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.openai.cli_args import (LoRAParserAction,
PromptAdapterParserAction)
from vllm.entrypoints.openai.zmq_server import run_zmq_server
from vllm.utils import FlexibleArgumentParser
class DisaggSubcommand(CLISubcommand):
    """The `disagg` subcommand for the vLLM CLI. """

    def __init__(self):
        self.name = "disagg"
        super().__init__()

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        # `vllm disagg <model>` takes the model positionally; reject the
        # case where only `--model` was supplied.
        if not args.model_tag:
            raise ValueError(
                "With `vllm disagg`, you should provide the model as a "
                "positional argument instead of via the `--model` option.")

        # EngineArgs expects the model name to be passed as --model.
        args.model = args.model_tag
        asyncio.run(run_zmq_server(args))

    def validate(self, args: argparse.Namespace) -> None:
        validate_parsed_disagg_args(args)

    def subparser_init(
            self,
            subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
        sub = subparsers.add_parser(
            "disagg",
            help="Start the vLLM OpenAI Compatible API zmq server",
            usage="vllm disagg <model_tag> [options]")
        return make_disagg_arg_parser(sub)
def cmd_init() -> list[CLISubcommand]:
    """Expose the `disagg` subcommand to the CLI registry."""
    subcommands: list[CLISubcommand] = [DisaggSubcommand()]
    return subcommands
def make_disagg_arg_parser(
        parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
    """Attach the `vllm disagg` CLI flags plus the engine args to *parser*."""
    parser.add_argument(
        "model_tag",
        type=str,
        help=
        "The model tag to use for the vLLM OpenAI Compatible API zmq server.")
    parser.add_argument(
        "--zmq-server-addr",
        type=str,
        required=True,
        help="The address to serve the zmq server on.")
    parser.add_argument(
        "--disable-frontend-multiprocessing",
        action="store_true",
        help="If specified, will run the OpenAI frontend server in the same "
        "process as the model serving engine.")
    parser.add_argument(
        "--return-tokens-as-token-ids",
        action="store_true",
        help="When ``--max-logprobs`` is specified, represents single tokens "
        " as strings of the form 'token_id:{token_id}' so that tokens "
        "that are not JSON-encodable can be identified.")
    parser.add_argument(
        "--max-log-len",
        type=int,
        default=None,
        help="Max number of prompt characters or prompt "
        "ID numbers being printed in log."
        "\n\nDefault: Unlimited")
    parser.add_argument(
        "--lora-modules",
        type=nullable_str,
        default=None,
        nargs="+",
        action=LoRAParserAction,
        help="LoRA module configurations in either 'name=path' format"
        "or JSON format. "
        "Example (old format): ``'name=path'`` "
        "Example (new format): "
        "``{\"name\": \"name\", \"path\": \"lora_path\", "
        "\"base_model_name\": \"id\"}``")
    parser.add_argument(
        "--prompt-adapters",
        type=nullable_str,
        default=None,
        nargs="+",
        action=PromptAdapterParserAction,
        help="Prompt adapter configurations in the format name=path. "
        "Multiple adapters can be specified.")
    # Engine-level flags (model loading, parallelism, ...) are shared with
    # `vllm serve` via AsyncEngineArgs so the two commands stay in sync.
    AsyncEngineArgs.add_cli_args(parser)
    return parser
def validate_parsed_disagg_args(args: argparse.Namespace):
    """Quick checks for model disagg args that raise prior to loading."""
    # Skip validation entirely when a different subcommand was invoked.
    if getattr(args, "subparser", "disagg") != "disagg":
        return

    # Enable reasoning needs a reasoning parser to be valid
    if args.enable_reasoning and not args.reasoning_parser:
        raise TypeError("Error: --enable-reasoning requires "
                        "--reasoning-parser")

View File

@ -6,8 +6,6 @@ import signal
import sys
import vllm.entrypoints.cli.benchmark.main
import vllm.entrypoints.cli.connect
import vllm.entrypoints.cli.disagg
import vllm.entrypoints.cli.openai
import vllm.entrypoints.cli.serve
import vllm.version
@ -20,8 +18,6 @@ CMD_MODULES = [
vllm.entrypoints.cli.openai,
vllm.entrypoints.cli.serve,
vllm.entrypoints.cli.benchmark.main,
vllm.entrypoints.cli.disagg,
vllm.entrypoints.cli.connect,
]

View File

@ -14,7 +14,7 @@ from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.entrypoints.openai.protocol import CompletionRequest
from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser
from vllm.utils import FlexibleArgumentParser, set_ulimit, make_zmq_socket
from vllm.entrypoints.openai.protocol import (
CompletionResponse, ErrorResponse)
@ -55,16 +55,22 @@ async def main(args, **uvicorn_kwargs):
logger.info("vLLM Disaggregate Connector Start %s %s", args,
uvicorn_kwargs)
# Avoid dropping requests under high concurrency.
set_ulimit()
# IPC Paths.
# NOTE FOR DEVELOPERS: when shifting to TCP, ensure you
# are not using pickle to avoid RCE security flaw.
prefill_addr = f"ipc://{args.prefill_addr}"
decode_addr = f"ipc://{args.decode_addr}"
connector_addr = f"ipc://{args.connector_addr}"
# Start Engine.
with pd_engine_client_ctx_manager(
args.model, prefill_addr, decode_addr, connector_addr) as engine_client:
# Initialize App State.
model_config = await engine_client.get_model_config()
# Models.
app.state.openai_serving_models = OpenAIServingModels(
engine_client=engine_client,
model_config=model_config,
@ -73,8 +79,6 @@ async def main(args, **uvicorn_kwargs):
model_path=args.model)
],
)
# Completions.
app.state.openai_serving_completion = OpenAIServingCompletion(
engine_client=engine_client,
model_config=model_config,
@ -82,10 +86,10 @@ async def main(args, **uvicorn_kwargs):
request_logger=None,
)
# Init Uvicorn Server. Server.
config = uvicorn.Config(app, host="0.0.0.0", port=args.port)
server = uvicorn.Server(config)
await server.serve()
# Run Server.
config = uvicorn.Config(app, host="0.0.0.0", port=args.port)
server = uvicorn.Server(config)
await server.serve()
if __name__ == "__main__":
parser = FlexibleArgumentParser(

View File

@ -2,6 +2,7 @@
import asyncio
import msgspec
import os
from collections.abc import AsyncGenerator
from typing import Dict, List, Mapping, Optional
@ -12,9 +13,7 @@ import zmq.asyncio
from vllm import SamplingParams
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
validate_parsed_serve_args)
from vllm.entrypoints.disaggregated.types import PDRequest, PDResponse
from vllm.inputs.data import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
@ -26,38 +25,12 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.utils import Device, FlexibleArgumentParser, make_zmq_socket
from vllm.utils import Device, make_zmq_socket
logger = init_logger(__name__)
DEFAULT_MAX_TOKENS = 32000
# NOTE FOR DEVELOPERS:
# DO NOT USE PICKLE FOR THESE CLASSES. IN A MULTI NODE
# SETUP WE WILL USE TCP. WE CANNOT USE PICKLE OTHERWISE
# WE RISK REMOTE CODE EXECUTION FROM UNSTRUSTED USERS.
class PDRequest(msgspec.Struct,
                array_like=True,  # type: ignore[call-arg]
                omit_defaults=True,  # type: ignore[call-arg]
                gc=False):  # type: ignore[call-arg]
    """Wire-format request sent to a prefill/decode worker.

    Serialized with msgspec/msgpack — never pickle — so untrusted peers
    cannot trigger remote code execution.
    """
    # Unique id used to route streamed responses back to the caller.
    request_id: str
    # Pre-tokenized prompt; raw text is not sent over the wire.
    prompt_token_ids: List[int]
    sampling_params: SamplingParams
    # TODO: support multimodal inputs.
class PDResponse(msgspec.Struct,
                 array_like=True,  # type: ignore[call-arg]
                 omit_defaults=True,  # type: ignore[call-arg]
                 gc=False):  # type: ignore[call-arg]
    """Wire-format response streamed back from a prefill/decode worker."""
    request_id: str
    # False signals that the worker hit an exception for this request.
    success: bool
    delta_text: Optional[str] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[str] = None
    # NOTE(review): no type annotation, so msgspec treats this as a plain
    # class attribute, not a serialized field — confirm that is intended.
    logprobs = None  # TODO
class PDEngine:
"""
PDEngine:
@ -89,6 +62,8 @@ class PDEngine:
self.to_prefill = make_zmq_socket(
self.ctx, f"{prefill_addr}", zmq.constants.PUSH)
self.connector_addr = connector_addr
self.decode_addr = decode_addr
self.prefill_addr = prefill_addr
# Background loops (started on first generate()).
self.output_handler: Optional[asyncio.Task] = None
@ -118,7 +93,6 @@ class PDEngine:
revision=self.model_config.tokenizer_revision,
truncation_side=self.model_config.truncation_side)
self.tokenizer = TokenizerGroup(**init_kwargs)
def shutdown(self):
if (ctx := self.ctx) is not None:
@ -128,6 +102,14 @@ class PDEngine:
if (task := self.output_handler) is not None:
task.cancel()
ipc_paths = [
self.connector_addr, self.decode_addr, self.prefill_addr
]
for path in ipc_paths:
socket_path = path.replace("ipc://", "")
if os.path.exists(socket_path):
os.remove(socket_path)
async def _run_log_running(self):
logger.info("Running requests: %d", len(self.queues))
await asyncio.sleep(10.)
@ -312,25 +294,3 @@ class PDEngine:
async def add_lora(self, lora_request: LoRARequest) -> None:
raise NotImplementedError
if __name__ == "__main__":
    # Standalone entry point: build the connector-specific flags first,
    # then layer the standard serve flags on top.
    cli_parser = FlexibleArgumentParser(
        description="vLLM OpenAI-Compatible RESTful API server.")
    cli_parser.add_argument(
        "--connector-addr",
        type=str,
        required=True,
        help="The zmq ipc connector address")
    cli_parser.add_argument(
        "--prefill-addr",
        type=str,
        required=True,
        help="The zmq ipc prefill address")
    cli_parser.add_argument(
        "--decode-addr",
        type=str,
        required=True,
        help="The zmq ipc decode address")
    cli_parser = make_arg_parser(cli_parser)
    parsed = cli_parser.parse_args()
    validate_parsed_serve_args(parsed)
    uvloop.run(run_server(parsed))

View File

@ -1,225 +1,126 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import json
import os
import msgpack
import signal
import traceback
from argparse import Namespace
from http import HTTPStatus
import uvloop
from typing import Optional
import zmq
import zmq.asyncio
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.disaggregated.types import PDRequest, PDResponse
from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.entrypoints.openai.protocol import (CompletionRequest,
CompletionResponse,
ErrorResponse, ZmqMsgRequest,
ZmqMsgResponse)
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.logger import init_logger
from vllm.utils import set_ulimit
from vllm.utils import FlexibleArgumentParser, set_ulimit, make_zmq_socket
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger('vllm.entrypoints.openai.zmq_server')
logger = init_logger(__name__)
openai_serving_completion: OpenAIServingCompletion
openai_serving_models: OpenAIServingModels
async def log_stats(running_requests: set[asyncio.Task]):
    """Periodically log how many request tasks are currently in flight."""
    while True:
        logger.info("Running requests: %d", len(running_requests))
        # 10-second cadence keeps the log readable under load.
        await asyncio.sleep(10)
def _cleanup_ipc_path(server_addr: str):
    """Remove the filesystem socket left behind by an ipc:// zmq address."""
    socket_path = server_addr.replace("ipc://", "")
    logger.info("cleaning up local IPC socket file %s", socket_path)
    # The file only exists if the socket was actually bound; ignore otherwise.
    if os.path.exists(socket_path):
        os.remove(socket_path)
async def serve_zmq(arg) -> None:
"""Server routine"""
logger.info("zmq Server start arg: %s, zmq_server_addr: %s", arg,
arg.zmq_server_addr)
# different zmq context can't communicate use inproc
server_addr = f"ipc://{arg.zmq_server_addr}"
async def handle_request(
request: PDRequest,
engine: EngineClient,
socket: zmq.asyncio.Socket,
encoder: msgpack.Encoder,
) -> None:
request_id = request.request_id
try:
# Prepare our context and sockets
context = zmq.asyncio.Context()
socket = context.socket(zmq.ROUTER)
# unlimited HWM
hwm_limit = 0
# 1) Generate RequestOutputs.
async for request_output in engine.generate(
prompt=request.prompt_token_ids,
sampling_params=request.sampling_params,
request_id=request_id):
socket.bind(server_addr)
socket.setsockopt(zmq.SNDHWM, hwm_limit)
socket.setsockopt(zmq.RCVHWM, hwm_limit)
assert len(request_output.outputs) == 0, "Only support N=1 right now."
out = request_output.outputs[0]
running_requests: set[asyncio.Task] = set()
logger.info("zmq Server started at %s", server_addr)
asyncio.create_task(log_stats(running_requests))
# 2) Convert RequestOutput --> PDResponse.
response = PDResponse(
request_id=request_id,
success=True,
text=out.text,
token_ids=out.token_ids,
finish_reasons=out.finish_reason,
stop_reason=out.stop_reason,
)
response_bytes = encoder(response)
# 3) Send to Connector.
await socket.send(response_bytes, copy=False)
except Exception as e:
# TODO: actual error handling.
logger.error("Exception in Worker Routine: %s request_id: %s", e,
request_id)
response = PDResponse(request_id=request_id, success=False)
response_bytes = encoder(response)
await socket.send(response, copy=False)
async def run_server(args, engine: EngineClient):
"""Get Requests and Handle Them."""
running_requests: set[asyncio.Task] = set()
decoder = msgpack.Decoder(PDRequest)
encoder = msgpack.Encoder()
ctx: Optional[zmq.asyncio.Context] = None
try:
# IPC Setup.
ctx = zmq.asyncio.Context()
from_connector = make_zmq_socket(
ctx, f"ipc://{args.server_addr}", zmq.constants.PULL)
to_connector = make_zmq_socket(
ctx, f"ipc://{args.connector_addr}", zmq.constants.PUSH)
# Main Loop.
while True:
try:
logger.debug("zmq Server waiting for request")
# get new request from the client
message_parts = await socket.recv_multipart()
logger.debug("received request: %s", message_parts)
logger.debug("received len: %d", len(message_parts))
identity, body = message_parts[0], message_parts[1]
zmq_msg_request = ZmqMsgRequest.model_validate_json(body)
# launch request handler coroutine
task = asyncio.create_task(
worker_routine(identity, zmq_msg_request, socket))
running_requests.add(task)
task.add_done_callback(running_requests.discard)
except zmq.ZMQError as e:
logger.error(traceback.format_exc())
logger.error("ZMQError: %s", e)
break
except Exception as e:
logger.error(traceback.format_exc())
logger.error("Unexpected error: %s", e)
break
# 1) Get request from the Connector.
pd_request_bytes = await from_connector.recv().buffer
pd_request = decoder(pd_request_bytes)
# 2) Launch a coroutine to handle the request.
task = asyncio.create_task(handle_request(
pd_request, engine, to_connector, encoder))
running_requests.add(task)
task.add_done_callback(running_requests.discard)
except KeyboardInterrupt:
logger.info("KeyboardInterrupt received, exiting")
logger.debug("Worker server loop interrupted.")
finally:
# Clean up resources
for task in running_requests:
task.cancel()
await asyncio.gather(*running_requests, return_exceptions=True)
socket.close()
context.destroy(linger=0)
_cleanup_ipc_path(server_addr)
if ctx is not None:
ctx.destroy(linger=0)
async def run_zmq_server(args) -> None:
logger.info("vLLM zmq server version %s", VLLM_VERSION)
async def main(args) -> None:
logger.info("vLLM P/D Worker Server %s", VLLM_VERSION)
logger.info("args: %s", args)
# workaround to avoid footguns where uvicorn drops requests with too
# many concurrent requests active
# Workaround to avoid footguns where uvicorn drops requests
# with too many concurrent requests active due to ulimit.
set_ulimit()
# Interrupt on sigterm during initialization.
def signal_handler(*_) -> None:
# Interrupt server on sigterm while initializing
raise KeyboardInterrupt("terminated")
signal.signal(signal.SIGTERM, signal_handler)
async with build_async_engine_client(args) as engine_client:
async with build_async_engine_client(args) as engine:
await run_server(args, engine)
model_config = await engine_client.get_model_config()
await init_state(engine_client, model_config, args)
logger.info("init_state successful")
await serve_zmq(args)
async def init_state(
    engine_client: EngineClient,
    model_config: ModelConfig,
    args: Namespace,
) -> None:
    """Build the module-level OpenAI serving objects used by the zmq server.

    Populates the `openai_serving_models` and `openai_serving_completion`
    globals and loads any static LoRA adapters.
    """
    global openai_serving_models, openai_serving_completion

    # Fall back to the engine model name when no serving aliases were given.
    names = (args.served_model_name
             if args.served_model_name is not None else [args.model])
    model_paths = [
        BaseModelPath(name=name, model_path=args.model) for name in names
    ]

    openai_serving_models = OpenAIServingModels(
        engine_client=engine_client,
        model_config=model_config,
        base_model_paths=model_paths,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
    )
    await openai_serving_models.init_static_loras()

    openai_serving_completion = OpenAIServingCompletion(
        engine_client,
        model_config,
        openai_serving_models,
        request_logger=None,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
async def worker_routine(identity: bytes, zmq_msg_request: ZmqMsgRequest,
                         socket: zmq.asyncio.Socket):
    """Dispatch one zmq message to the matching OpenAI handler.

    On any failure, replies to `identity` with a ZmqMsgResponse whose JSON
    body carries the error content and a 500 status, so the connector sees
    the failure instead of waiting forever.
    """
    # Bind request_id before the try block: previously, if the first
    # statement inside `try` raised, the except handler referenced an
    # unbound `request_id` and crashed with NameError.
    request_id = zmq_msg_request.request_id
    try:
        logger.debug("receive request: %s from %s, request_id: %s",
                     zmq_msg_request.model_dump_json(), identity.decode(),
                     request_id)
        if isinstance(zmq_msg_request.body, CompletionRequest):
            await create_completion(identity, zmq_msg_request, socket)
        else:
            logger.error("Error in worker routine: %s request_id: %s",
                         "unsupported request type", request_id)
            raise Exception("unsupported request type")
    except Exception as e:
        logger.error("Error in worker routine: %s request_id: %s", e,
                     request_id)
        logger.error(traceback.format_exc())
        logger.debug("send ErrorResponse %s", str(e))
        await socket.send_multipart([
            identity,
            ZmqMsgResponse(request_id=request_id,
                           type=zmq_msg_request.type,
                           body=json.dumps({
                               # BUG FIX: report the actual error instead of
                               # always claiming "unsupported request type".
                               "content": str(e),
                               "status_code":
                               HTTPStatus.INTERNAL_SERVER_ERROR
                           })).model_dump_json().encode(),
        ])
async def create_completion(identity: bytes, zmq_msg_request: ZmqMsgRequest,
                            socket: zmq.asyncio.Socket):
    """Run an OpenAI completion and relay the result(s) over zmq.

    Non-streaming results (ErrorResponse / CompletionResponse) are sent as
    a single multipart message; streaming results are forwarded chunk by
    chunk, with `stop` cleared on every chunk except the final
    "data: [DONE]" sentinel.
    """
    request: CompletionRequest = zmq_msg_request.body
    logger.debug("zmq request post: %s", request.model_dump_json())
    # Returns either a concrete response object or an async generator of
    # SSE-formatted chunks, depending on request.stream.
    generator = await openai_serving_completion.create_completion(request)
    logger.debug("zmq request end post")
    request_id = zmq_msg_request.request_id
    if isinstance(generator, (ErrorResponse, CompletionResponse)):
        # Non-streaming path: wrap the whole payload in one response body.
        logger.debug("send response %s", generator.model_dump_json())
        if isinstance(generator, ErrorResponse):
            body = json.dumps({
                "content": generator.model_dump(),
                "status_code": generator.code
            })
        elif isinstance(generator, CompletionResponse):
            body = json.dumps({"content": generator.model_dump()})
        zmq_msg_response = ZmqMsgResponse(request_id=request_id,
                                          type=zmq_msg_request.type,
                                          body_type="response",
                                          body=body)
        # ROUTER socket framing: [routing identity, payload].
        await socket.send_multipart(
            [identity, zmq_msg_response.model_dump_json().encode()])
    else:
        # Streaming path: forward each SSE chunk as its own message.
        async for chunk in generator:
            zmq_msg_response = ZmqMsgResponse(request_id=request_id,
                                              type=zmq_msg_request.type,
                                              body=chunk)
            # Only the final sentinel chunk keeps the default stop=True.
            if "data: [DONE]" not in chunk:
                zmq_msg_response.stop = False
            logger.debug("send chunk identity: %s, request_id: %s, chunk: %s",
                         identity.decode(), request_id, chunk)
            await socket.send_multipart(
                [identity,
                 zmq_msg_response.model_dump_json().encode()])
if __name__ == "__main__":
    # Standalone worker entry point: worker/connector IPC addresses plus
    # the full set of engine flags.
    cli = FlexibleArgumentParser()
    cli.add_argument(
        "--connector-addr",
        type=str,
        required=True,
        help="The address of the connector.")
    cli.add_argument(
        "--worker-addr",
        type=str,
        required=True,
        help="The address of the worker.")
    AsyncEngineArgs.add_cli_args(cli)
    parsed_args = cli.parse_args()
    uvloop.run(main(parsed_args))