From 85687b43e7e25a50101a92688ae04e024c891f07 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sat, 22 Mar 2025 17:00:46 -0400 Subject: [PATCH] updated Signed-off-by: Robert Shaw --- vllm/entrypoints/cli/connect.py | 72 ----- vllm/entrypoints/cli/disagg.py | 115 ------- vllm/entrypoints/cli/main.py | 4 - vllm/entrypoints/disaggregated/api_server.py | 22 +- vllm/entrypoints/disaggregated/pd_engine.py | 66 +--- .../disaggregated/worker_server.py | 283 ++++++------------ 6 files changed, 118 insertions(+), 444 deletions(-) delete mode 100644 vllm/entrypoints/cli/connect.py delete mode 100644 vllm/entrypoints/cli/disagg.py diff --git a/vllm/entrypoints/cli/connect.py b/vllm/entrypoints/cli/connect.py deleted file mode 100644 index 467f6a518d033..0000000000000 --- a/vllm/entrypoints/cli/connect.py +++ /dev/null @@ -1,72 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import argparse - -import uvloop - -from vllm.entrypoints.cli.types import CLISubcommand -from vllm.entrypoints.disagg_connector import run_disagg_connector -from vllm.utils import FlexibleArgumentParser - - -class ConnectSubcommand(CLISubcommand): - """The `connect` subcommand for the vLLM CLI. """ - - def __init__(self): - self.name = "connect" - super().__init__() - - @staticmethod - def cmd(args: argparse.Namespace) -> None: - uvloop.run(run_disagg_connector(args)) - - def validate(self, args: argparse.Namespace) -> None: - validate_connect_parsed_args(args) - - def subparser_init( - self, - subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: - connect_parser = subparsers.add_parser( - "connect", - help= - "Start the vLLM OpenAI Compatible API Server which connect to other" - "servers disaggreate prefill and decode", - usage="vllm connect [options]") - - return make_connect_arg_parser(connect_parser) - - -def cmd_init() -> list[CLISubcommand]: - return [ConnectSubcommand()] - - -def make_connect_arg_parser( - parser: FlexibleArgumentParser) -> FlexibleArgumentParser: - parser.add_argument("--port", - type=int, - default=7001, - help="The fastapi server port default 7001") - # support ipc only now, support tcp later(with auth) - parser.add_argument( - "--protocol", - type=str, - choices=["ipc"], - default="ipc", - help="The zmq socket addr protocol IPC (Inter-Process Communication)") - # security concern only support ipc now - parser.add_argument("--prefill-addr", - type=str, - required=True, - help="The zmq ipc prefill address") - parser.add_argument("--decode-addr", - type=str, - required=True, - help="The zmq ipc decode address") - - return parser - - -def validate_connect_parsed_args(args: argparse.Namespace): - """Quick checks for connect args that raise prior to loading.""" - if hasattr(args, "subparser") and args.subparser != "connect": - return diff --git a/vllm/entrypoints/cli/disagg.py b/vllm/entrypoints/cli/disagg.py deleted file mode 100644 index 5d286e8e0b7aa..0000000000000 --- a/vllm/entrypoints/cli/disagg.py +++ /dev/null @@ -1,115 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import argparse -import asyncio - -from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str -from vllm.entrypoints.cli.types import CLISubcommand -from vllm.entrypoints.openai.cli_args import (LoRAParserAction, - PromptAdapterParserAction) -from vllm.entrypoints.openai.zmq_server import run_zmq_server -from vllm.utils import FlexibleArgumentParser - - -class DisaggSubcommand(CLISubcommand): - """The `disagg` subcommand for the vLLM CLI. """ - - def __init__(self): - self.name = "disagg" - super().__init__() - - @staticmethod - def cmd(args: argparse.Namespace) -> None: - # The default value of `--model` - if not args.model_tag: - raise ValueError( - "With `vllm disagg`, you should provide the model as a " - "positional argument instead of via the `--model` option.") - - # EngineArgs expects the model name to be passed as --model. - args.model = args.model_tag - - asyncio.run(run_zmq_server(args)) - - def validate(self, args: argparse.Namespace) -> None: - validate_parsed_disagg_args(args) - - def subparser_init( - self, - subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: - disagg_parser = subparsers.add_parser( - "disagg", - help="Start the vLLM OpenAI Compatible API zmq server", - usage="vllm disagg [options]") - - return make_disagg_arg_parser(disagg_parser) - - -def cmd_init() -> list[CLISubcommand]: - return [DisaggSubcommand()] - - -def make_disagg_arg_parser( - parser: FlexibleArgumentParser) -> FlexibleArgumentParser: - parser.add_argument( - "model_tag", - type=str, - help= - "The model tag to use for the vLLM OpenAI Compatible API zmq server.") - parser.add_argument('--zmq-server-addr', - type=str, - required=True, - help='The address to serve the zmq server on.') - parser.add_argument( - "--disable-frontend-multiprocessing", - action="store_true", - help="If specified, will run the OpenAI frontend server in the same " - "process as the model serving engine.") - parser.add_argument( - "--return-tokens-as-token-ids", - action="store_true", - help="When ``--max-logprobs`` is specified, represents single tokens " - " as strings of the form 'token_id:{token_id}' so that tokens " - "that are not JSON-encodable can be identified.") - parser.add_argument('--max-log-len', - type=int, - default=None, - help='Max number of prompt characters or prompt ' - 'ID numbers being printed in log.' - '\n\nDefault: Unlimited') - parser.add_argument( - "--lora-modules", - type=nullable_str, - default=None, - nargs='+', - action=LoRAParserAction, - help="LoRA module configurations in either 'name=path' format" - "or JSON format. " - "Example (old format): ``'name=path'`` " - "Example (new format): " - "``{\"name\": \"name\", \"path\": \"lora_path\", " - "\"base_model_name\": \"id\"}``") - - parser.add_argument( - "--prompt-adapters", - type=nullable_str, - default=None, - nargs='+', - action=PromptAdapterParserAction, - help="Prompt adapter configurations in the format name=path. " - "Multiple adapters can be specified.") - - AsyncEngineArgs.add_cli_args(parser) - - return parser - - -def validate_parsed_disagg_args(args: argparse.Namespace): - """Quick checks for model disagg args that raise prior to loading.""" - if hasattr(args, "subparser") and args.subparser != "disagg": - return - - # Enable reasoning needs a reasoning parser to be valid - if args.enable_reasoning and not args.reasoning_parser: - raise TypeError("Error: --enable-reasoning requires " - "--reasoning-parser") diff --git a/vllm/entrypoints/cli/main.py b/vllm/entrypoints/cli/main.py index e9edbf663ee8e..13f2761b0db06 100644 --- a/vllm/entrypoints/cli/main.py +++ b/vllm/entrypoints/cli/main.py @@ -6,8 +6,6 @@ import signal import sys import vllm.entrypoints.cli.benchmark.main -import vllm.entrypoints.cli.connect -import vllm.entrypoints.cli.disagg import vllm.entrypoints.cli.openai import vllm.entrypoints.cli.serve import vllm.version @@ -20,8 +18,6 @@ CMD_MODULES = [ vllm.entrypoints.cli.openai, vllm.entrypoints.cli.serve, vllm.entrypoints.cli.benchmark.main, - vllm.entrypoints.cli.disagg, - vllm.entrypoints.cli.connect, ] diff --git a/vllm/entrypoints/disaggregated/api_server.py b/vllm/entrypoints/disaggregated/api_server.py index 95c5e5df79a1d..2b18e594163f3 100644 --- a/vllm/entrypoints/disaggregated/api_server.py +++ b/vllm/entrypoints/disaggregated/api_server.py @@ -14,7 +14,7 @@ from vllm.entrypoints.openai.serving_models import (BaseModelPath, OpenAIServingModels) from vllm.entrypoints.openai.protocol import CompletionRequest from vllm.logger import init_logger -from vllm.utils import FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser, set_ulimit, make_zmq_socket from vllm.entrypoints.openai.protocol import ( CompletionResponse, ErrorResponse) @@ -55,16 +55,22 @@ async def main(args, **uvicorn_kwargs): logger.info("vLLM Disaggregate Connector Start %s %s", args, uvicorn_kwargs) + # Avoid dropping requests under high concurrency. + set_ulimit() + + # IPC Paths. + # NOTE FOR DEVELOPERS: when shifting to TCP, ensure you + # are not using pickle to avoid RCE security flaw. prefill_addr = f"ipc://{args.prefill_addr}" decode_addr = f"ipc://{args.decode_addr}" connector_addr = f"ipc://{args.connector_addr}" + # Start Engine. with pd_engine_client_ctx_manager( args.model, prefill_addr, decode_addr, connector_addr) as engine_client: + # Initialize App State. model_config = await engine_client.get_model_config() - - # Models. app.state.openai_serving_models = OpenAIServingModels( engine_client=engine_client, model_config=model_config, @@ -73,8 +79,6 @@ async def main(args, **uvicorn_kwargs): model_path=args.model) ], ) - - # Completions. app.state.openai_serving_completion = OpenAIServingCompletion( engine_client=engine_client, model_config=model_config, @@ -82,10 +86,10 @@ async def main(args, **uvicorn_kwargs): request_logger=None, ) - # Init Uvicorn Server. Server. - config = uvicorn.Config(app, host="0.0.0.0", port=args.port) - server = uvicorn.Server(config) - await server.serve() + # Run Server. + config = uvicorn.Config(app, host="0.0.0.0", port=args.port) + server = uvicorn.Server(config) + await server.serve() if __name__ == "__main__": parser = FlexibleArgumentParser( diff --git a/vllm/entrypoints/disaggregated/pd_engine.py b/vllm/entrypoints/disaggregated/pd_engine.py index 0efed4e3b973b..fc2d7177bf99b 100644 --- a/vllm/entrypoints/disaggregated/pd_engine.py +++ b/vllm/entrypoints/disaggregated/pd_engine.py @@ -2,6 +2,7 @@ import asyncio import msgspec +import os from collections.abc import AsyncGenerator from typing import Dict, List, Mapping, Optional @@ -12,9 +13,7 @@ import zmq.asyncio from vllm import SamplingParams from vllm.config import DecodingConfig, ModelConfig from vllm.core.scheduler import SchedulerOutputs -from vllm.entrypoints.openai.api_server import run_server -from vllm.entrypoints.openai.cli_args import (make_arg_parser, - validate_parsed_serve_args) +from vllm.entrypoints.disaggregated.types import PDRequest, PDResponse from vllm.inputs.data import PromptType from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger @@ -26,38 +25,12 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import TokenizerGroup -from vllm.utils import Device, FlexibleArgumentParser, make_zmq_socket +from vllm.utils import Device, make_zmq_socket logger = init_logger(__name__) DEFAULT_MAX_TOKENS = 32000 -# NOTE FOR DEVELOPERS: -# DO NOT USE PICKLE FOR THESE CLASSES. IN A MULTI NODE -# SETUP WE WILL USE TCP. WE CANNOT USE PICKLE OTHERWISE -# WE RISK REMOTE CODE EXECUTION FROM UNSTRUSTED USERS. - -class PDRequest(msgspec.Struct, - array_like=True, # type: ignore[call-arg] - omit_defaults=True, # type: ignore[call-arg] - gc=False): # type: ignore[call-arg] - request_id: str - prompt_token_ids: List[int] - sampling_params: SamplingParams - # TODO: support multimodal inputs. - -class PDResponse(msgspec.Struct, - array_like=True, # type: ignore[call-arg] - omit_defaults=True, # type: ignore[call-arg] - gc=False): # type: ignore[call-arg] - request_id: str - success: bool - delta_text: Optional[str] = None - finish_reason: Optional[str] = None - stop_reason: Optional[str] = None - logprobs = None # TODO - - class PDEngine: """ PDEngine: @@ -89,6 +62,8 @@ class PDEngine: self.to_prefill = make_zmq_socket( self.ctx, f"{prefill_addr}", zmq.constants.PUSH) self.connector_addr = connector_addr + self.decode_addr = decode_addr + self.prefill_addr = prefill_addr # Background loops (started on first generate()). self.output_handler: Optional[asyncio.Task] = None @@ -118,7 +93,6 @@ class PDEngine: revision=self.model_config.tokenizer_revision, truncation_side=self.model_config.truncation_side) self.tokenizer = TokenizerGroup(**init_kwargs) - def shutdown(self): if (ctx := self.ctx) is not None: @@ -128,6 +102,14 @@ class PDEngine: if (task := self.output_handler) is not None: task.cancel() + ipc_paths = [ + self.connector_addr, self.decode_addr, self.prefill_addr + ] + for path in ipc_paths: + socket_path = path.replace("ipc://", "") + if os.path.exists(socket_path): + os.remove(socket_path) + async def _run_log_running(self): logger.info("Running requests: %d", len(self.queues)) await asyncio.sleep(10.) @@ -312,25 +294,3 @@ class PDEngine: async def add_lora(self, lora_request: LoRARequest) -> None: raise NotImplementedError - - -if __name__ == "__main__": - parser = FlexibleArgumentParser( - description="vLLM OpenAI-Compatible RESTful API server.") - parser.add_argument("--connector-addr", - type=str, - required=True, - help="The zmq ipc connector address") - parser.add_argument("--prefill-addr", - type=str, - required=True, - help="The zmq ipc prefill address") - parser.add_argument("--decode-addr", - type=str, - required=True, - help="The zmq ipc decode address") - parser = make_arg_parser(parser) - args = parser.parse_args() - validate_parsed_serve_args(args) - - uvloop.run(run_server(args)) diff --git a/vllm/entrypoints/disaggregated/worker_server.py b/vllm/entrypoints/disaggregated/worker_server.py index caba77e3a7462..f403c9dbd256d 100644 --- a/vllm/entrypoints/disaggregated/worker_server.py +++ b/vllm/entrypoints/disaggregated/worker_server.py @@ -1,225 +1,126 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio -import json -import os +import msgpack import signal -import traceback -from argparse import Namespace -from http import HTTPStatus +import uvloop +from typing import Optional import zmq import zmq.asyncio -from vllm.config import ModelConfig +from vllm.engine.async_llm_engine import AsyncEngineArgs from vllm.engine.protocol import EngineClient +from vllm.entrypoints.disaggregated.types import PDRequest, PDResponse from vllm.entrypoints.openai.api_server import build_async_engine_client -from vllm.entrypoints.openai.protocol import (CompletionRequest, - CompletionResponse, - ErrorResponse, ZmqMsgRequest, - ZmqMsgResponse) -from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion -from vllm.entrypoints.openai.serving_models import (BaseModelPath, - OpenAIServingModels) from vllm.logger import init_logger -from vllm.utils import set_ulimit +from vllm.utils import FlexibleArgumentParser, set_ulimit, make_zmq_socket from vllm.version import __version__ as VLLM_VERSION -logger = init_logger('vllm.entrypoints.openai.zmq_server') +logger = init_logger(__name__) -openai_serving_completion: OpenAIServingCompletion -openai_serving_models: OpenAIServingModels - - -async def log_stats(running_requests: set[asyncio.Task]): - while True: - logger.info("Running requests: %d", len(running_requests)) - await asyncio.sleep(10) - - -def _cleanup_ipc_path(server_addr: str): - socket_path = server_addr.replace("ipc://", "") - logger.info("cleaning up local IPC socket file %s", socket_path) - if os.path.exists(socket_path): - os.remove(socket_path) - - -async def serve_zmq(arg) -> None: - """Server routine""" - logger.info("zmq Server start arg: %s, zmq_server_addr: %s", arg, - arg.zmq_server_addr) - # different zmq context can't communicate use inproc - server_addr = f"ipc://{arg.zmq_server_addr}" +async def handle_request( + request: PDRequest, + engine: EngineClient, + socket: zmq.asyncio.Socket, + encoder: msgpack.Encoder, +) -> None: + request_id = request.request_id try: - # Prepare our context and sockets - context = zmq.asyncio.Context() - socket = context.socket(zmq.ROUTER) - # unlimited HWM - hwm_limit = 0 + # 1) Generate RequestOutputs. + async for request_output in engine.generate( + prompt=request.prompt_token_ids, + sampling_params=request.sampling_params, + request_id=request_id): - socket.bind(server_addr) - socket.setsockopt(zmq.SNDHWM, hwm_limit) - socket.setsockopt(zmq.RCVHWM, hwm_limit) + assert len(request_output.outputs) == 0, "Only support N=1 right now." + out = request_output.outputs[0] - running_requests: set[asyncio.Task] = set() - logger.info("zmq Server started at %s", server_addr) - asyncio.create_task(log_stats(running_requests)) + # 2) Convert RequestOutput --> PDResponse. + response = PDResponse( + request_id=request_id, + success=True, + text=out.text, + token_ids=out.token_ids, + finish_reasons=out.finish_reason, + stop_reason=out.stop_reason, + ) + response_bytes = encoder(response) + # 3) Send to Connector. + await socket.send(response_bytes, copy=False) + + except Exception as e: + # TODO: actual error handling. + logger.error("Exception in Worker Routine: %s request_id: %s", e, + request_id) + response = PDResponse(request_id=request_id, success=False) + response_bytes = encoder(response) + await socket.send(response, copy=False) + +async def run_server(args, engine: EngineClient): + """Get Requests and Handle Them.""" + running_requests: set[asyncio.Task] = set() + decoder = msgpack.Decoder(PDRequest) + encoder = msgpack.Encoder() + + ctx: Optional[zmq.asyncio.Context] = None + try: + # IPC Setup. + ctx = zmq.asyncio.Context() + from_connector = make_zmq_socket( + ctx, f"ipc://{args.server_addr}", zmq.constants.PULL) + to_connector = make_zmq_socket( + ctx, f"ipc://{args.connector_addr}", zmq.constants.PUSH) + + # Main Loop. while True: - try: - logger.debug("zmq Server waiting for request") - # get new request from the client - message_parts = await socket.recv_multipart() - logger.debug("received request: %s", message_parts) - logger.debug("received len: %d", len(message_parts)) - identity, body = message_parts[0], message_parts[1] - zmq_msg_request = ZmqMsgRequest.model_validate_json(body) - # launch request handler coroutine - task = asyncio.create_task( - worker_routine(identity, zmq_msg_request, socket)) - running_requests.add(task) - task.add_done_callback(running_requests.discard) - except zmq.ZMQError as e: - logger.error(traceback.format_exc()) - logger.error("ZMQError: %s", e) - break - except Exception as e: - logger.error(traceback.format_exc()) - logger.error("Unexpected error: %s", e) - break + # 1) Get request from the Connector. + pd_request_bytes = await from_connector.recv().buffer + pd_request = decoder(pd_request_bytes) + + # 2) Launch a coroutine to handle the request. + task = asyncio.create_task(handle_request( + pd_request, engine, to_connector, encoder)) + running_requests.add(task) + task.add_done_callback(running_requests.discard) + except KeyboardInterrupt: - logger.info("KeyboardInterrupt received, exiting") + logger.debug("Worker server loop interrupted.") + finally: - # Clean up resources for task in running_requests: task.cancel() - await asyncio.gather(*running_requests, return_exceptions=True) - socket.close() - context.destroy(linger=0) - _cleanup_ipc_path(server_addr) + if ctx is not None: + ctx.destroy(linger=0) -async def run_zmq_server(args) -> None: - logger.info("vLLM zmq server version %s", VLLM_VERSION) +async def main(args) -> None: + logger.info("vLLM P/D Worker Server %s", VLLM_VERSION) logger.info("args: %s", args) - # workaround to avoid footguns where uvicorn drops requests with too - # many concurrent requests active + # Workaround to avoid footguns where uvicorn drops requests + # with too many concurrent requests active due to ulimit. set_ulimit() + # Interrupt on sigterm during initialization. def signal_handler(*_) -> None: - # Interrupt server on sigterm while initializing raise KeyboardInterrupt("terminated") - signal.signal(signal.SIGTERM, signal_handler) - async with build_async_engine_client(args) as engine_client: + async with build_async_engine_client(args) as engine: + await run_server(args, engine) - model_config = await engine_client.get_model_config() - await init_state(engine_client, model_config, args) - logger.info("init_state successful") - await serve_zmq(args) - - -async def init_state( - engine_client: EngineClient, - model_config: ModelConfig, - args: Namespace, -) -> None: - if args.served_model_name is not None: - served_model_names = args.served_model_name - else: - served_model_names = [args.model] - - base_model_paths = [ - BaseModelPath(name=name, model_path=args.model) - for name in served_model_names - ] - - global openai_serving_models - openai_serving_models = OpenAIServingModels( - engine_client=engine_client, - model_config=model_config, - base_model_paths=base_model_paths, - lora_modules=args.lora_modules, - prompt_adapters=args.prompt_adapters, - ) - await openai_serving_models.init_static_loras() - - global openai_serving_completion - openai_serving_completion = OpenAIServingCompletion( - engine_client, - model_config, - openai_serving_models, - request_logger=None, - return_tokens_as_token_ids=args.return_tokens_as_token_ids, - ) - - -async def worker_routine(identity: bytes, zmq_msg_request: ZmqMsgRequest, - socket: zmq.asyncio.Socket): - """Worker routine""" - try: - request_id = zmq_msg_request.request_id - logger.debug("receive request: %s from %s, request_id: %s", - zmq_msg_request.model_dump_json(), identity.decode(), - request_id) - if isinstance(zmq_msg_request.body, CompletionRequest): - await create_completion(identity, zmq_msg_request, socket) - else: - logger.error("Error in worker routine: %s request_id: %s", - "unsupported request type", request_id) - raise Exception("unsupported request type") - - except Exception as e: - logger.error("Error in worker routine: %s request_id: %s", e, - request_id) - logger.error(traceback.format_exc()) - logger.debug("send ErrorResponse %s", str(e)) - await socket.send_multipart([ - identity, - ZmqMsgResponse(request_id=request_id, - type=zmq_msg_request.type, - body=json.dumps({ - "content": - "unsupported request type", - "status_code": - HTTPStatus.INTERNAL_SERVER_ERROR - })).model_dump_json().encode(), - ]) - - -async def create_completion(identity: bytes, zmq_msg_request: ZmqMsgRequest, - socket: zmq.asyncio.Socket): - request: CompletionRequest = zmq_msg_request.body - logger.debug("zmq request post: %s", request.model_dump_json()) - generator = await openai_serving_completion.create_completion(request) - logger.debug("zmq request end post") - request_id = zmq_msg_request.request_id - if isinstance(generator, (ErrorResponse, CompletionResponse)): - logger.debug("send response %s", generator.model_dump_json()) - if isinstance(generator, ErrorResponse): - body = json.dumps({ - "content": generator.model_dump(), - "status_code": generator.code - }) - elif isinstance(generator, CompletionResponse): - body = json.dumps({"content": generator.model_dump()}) - zmq_msg_response = ZmqMsgResponse(request_id=request_id, - type=zmq_msg_request.type, - body_type="response", - body=body) - await socket.send_multipart( - [identity, zmq_msg_response.model_dump_json().encode()]) - else: - async for chunk in generator: - zmq_msg_response = ZmqMsgResponse(request_id=request_id, - type=zmq_msg_request.type, - body=chunk) - if "data: [DONE]" not in chunk: - zmq_msg_response.stop = False - logger.debug("send chunk identity: %s, request_id: %s, chunk: %s", - identity.decode(), request_id, chunk) - await socket.send_multipart( - [identity, - zmq_msg_response.model_dump_json().encode()]) +if __name__ == "__main__": + parser = FlexibleArgumentParser() + parser.add_argument('--connector-addr', + type=str, + required=True, + help='The address of the connector.') + parser.add_argument('--worker-addr', + type=str, + required=True, + help='The address of the worker.') + AsyncEngineArgs.add_cli_args(parser) + args = parser.parse_args() + uvloop.run(main(args))