From a8a621e4196e709b19cdbea7f2afeea1baac2eb7 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sat, 22 Mar 2025 13:11:50 -0400 Subject: [PATCH] updated Signed-off-by: Robert Shaw --- vllm/entrypoints/disagg_connector.py | 283 ------------------ vllm/entrypoints/disaggregated/__init__.py | 0 vllm/entrypoints/disaggregated/connector.py | 305 ++++++++++++++++++++ 3 files changed, 305 insertions(+), 283 deletions(-) delete mode 100644 vllm/entrypoints/disagg_connector.py create mode 100644 vllm/entrypoints/disaggregated/__init__.py create mode 100644 vllm/entrypoints/disaggregated/connector.py diff --git a/vllm/entrypoints/disagg_connector.py b/vllm/entrypoints/disagg_connector.py deleted file mode 100644 index e7e7e55627b99..0000000000000 --- a/vllm/entrypoints/disagg_connector.py +++ /dev/null @@ -1,283 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import asyncio -import signal -import sys -import traceback -import uuid -from collections.abc import AsyncGenerator -from contextlib import asynccontextmanager -from http import HTTPStatus -from typing import Union - -import uvicorn -import uvloop -import zmq -import zmq.asyncio -from fastapi import BackgroundTasks, FastAPI, Request -from fastapi.responses import JSONResponse, StreamingResponse - -from vllm.entrypoints.openai.protocol import (CompletionRequest, ZmqMsgRequest, - ZmqMsgResponse) -from vllm.logger import init_logger -from vllm.utils import FlexibleArgumentParser - -# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) -logger = init_logger('vllm.entrypoints.disagg_connector') - -TIME_OUT = 5 -X_REQUEST_ID_KEY = "X-Request-Id" -CONTENT_TYPE_STREAM = "text/event-stream" - -# communication between output handlers and execute_task_async -request_queues: dict[str, asyncio.Queue] - - -async def log_stats(request_queues: dict[str, asyncio.Queue]): - while True: - logger.info("Running requests: %d", len(request_queues)) - await asyncio.sleep(10) - - -# create async socket use ZMQ_DEALER -async def create_socket(url: str, - zmqctx: zmq.asyncio.Context) -> zmq.asyncio.Socket: - socket = zmqctx.socket(zmq.DEALER) - identity = f"connector-{uuid.uuid4()}" - # unlimited HWM - hwm_limit = 0 - socket.setsockopt(zmq.IDENTITY, identity.encode()) - socket.setsockopt(zmq.SNDHWM, hwm_limit) - socket.setsockopt(zmq.RCVHWM, hwm_limit) - socket.connect(url) - logger.info("%s started at %s", identity, url) - return socket - - -@asynccontextmanager -async def lifespan(app: FastAPI): - # create socket pool with prefill and decode - logger.info("start connect zmq server") - app.state.zmqctx = zmq.asyncio.Context() - app.state.prefill_socket = await create_socket(app.state.prefill_addr, - zmqctx=app.state.zmqctx) - logger.info("success create_socke sockets_prefill") - app.state.decode_socket = await create_socket(app.state.decode_addr, - zmqctx=app.state.zmqctx) - logger.info("success create_socket sockets_decode") - global request_queues - request_queues = {} - asyncio.create_task(prefill_handler(app.state.prefill_socket)) - asyncio.create_task(decode_handler(app.state.decode_socket)) - asyncio.create_task(log_stats(request_queues)) - yield - ## close zmq context - logger.info("shutdown disagg connector") - logger.info("term zmqctx") - app.state.zmqctx.destroy(linger=0) - - -app = FastAPI(lifespan=lifespan) - - -@app.post('/v1/completions') -async def completions(request: CompletionRequest, raw_request: Request, - background_tasks: BackgroundTasks): - try: - # Add the X-Request-Id header to the raw headers list - header = dict(raw_request.headers) - request_id = header.get(X_REQUEST_ID_KEY) - queue: asyncio.Queue[ZmqMsgResponse] = asyncio.Queue() - if request_id is None: - request_id = str(uuid.uuid4()) - logger.debug("add X-Request-Id: %s", request_id) - logger.debug("X-Request-Id is: %s", request_id) - request_queues[request_id] = queue - zmq_msg_request = ZmqMsgRequest(request_id=request_id, - type="completions", - body=request) - logger.info("Received request_id: %s, request: %s, header: %s", - request_id, zmq_msg_request.model_dump_json(), header) - original_max_tokens = request.max_tokens - # change max_tokens = 1 to let it only do prefill - request.max_tokens = 1 - # finish prefill - try: - prefill_response = await prefill(zmq_msg_request) - if isinstance(prefill_response, JSONResponse - ) and prefill_response.status_code != HTTPStatus.OK: - return prefill_response - logger.debug("finish prefill start decode") - request.max_tokens = original_max_tokens - response = await decode(zmq_msg_request) - logger.debug("finish decode") - except Exception as e: - logger.error("Error occurred in disagg prefill proxy server, %s", - e) - response = JSONResponse( - {"error": { - "message": str(e) - }}, - status_code=HTTPStatus.INTERNAL_SERVER_ERROR) - return response - - except Exception as e: - exc_info = sys.exc_info() - logger.error("Error occurred in disagg prefill proxy server") - logger.error(e) - logger.error("".join(traceback.format_exception(*exc_info))) - response = JSONResponse({"error": { - "message": str(e) - }}, HTTPStatus.INTERNAL_SERVER_ERROR) - return response - finally: - if request_id is not None: - background_tasks.add_task(cleanup_request_id, request_id) - - -async def socket_recv_handler(socket: zmq.asyncio.Socket, scene: str): - while True: - try: - [body] = await socket.recv_multipart() - response = ZmqMsgResponse.model_validate_json(body) - request_id = response.request_id - logger.debug("%s socket received result: %s", scene, - response.model_dump_json()) - if request_id in request_queues: - request_queues[request_id].put_nowait(response) - else: - logger.debug( - "%s socket received but request_id not found discard: %s", - scene, request_id) - except Exception as e: - logger.error(traceback.format_exc()) - logger.error("%s handler error: %s", scene, e) - - -# prefill handler -async def prefill_handler(prefill_socket: zmq.asyncio.Socket): - await socket_recv_handler(prefill_socket, "prefill") - - -# decode handler -async def decode_handler(decode_socket: zmq.asyncio.Socket): - await socket_recv_handler(decode_socket, "decode") - - -# select a socket and execute task -async def execute_task_async(zmq_msg_request: ZmqMsgRequest, - socket: zmq.asyncio.Socket): - try: - request_id = zmq_msg_request.request_id - requestBody = zmq_msg_request.model_dump_json() - logger.debug("Sending requestBody: %s", requestBody) - socket.send_multipart([requestBody.encode()]) - logger.debug("Sent end") - queue = request_queues[request_id] - while True: - logger.debug("Waiting for reply") - zmq_msg_response: ZmqMsgResponse = await asyncio.wait_for( - queue.get(), TIME_OUT) - logger.debug("Received result: %s", - zmq_msg_response.model_dump_json()) - yield zmq_msg_response - if zmq_msg_response.stop: - logger.debug("Received stop: %s", zmq_msg_response.stop) - break - except asyncio.TimeoutError: - logger.error(traceback.format_exc()) - yield JSONResponse("timeout", HTTPStatus.REQUEST_TIMEOUT) - finally: - logger.debug("request_id: %s, execute_task_async end", request_id) - - -async def prefill(zmq_msg_request: ZmqMsgRequest) -> Union[JSONResponse, bool]: - logger.debug("start prefill") - generator = execute_task_async(zmq_msg_request, app.state.prefill_socket) - async for res in generator: - logger.debug("res: %s", res) - if res.body_type == "response": - return JSONResponse(res.body) - return True - - -async def generate_stream_response( - fisrt_reply: str, - generator: AsyncGenerator[ZmqMsgResponse]) -> AsyncGenerator[str]: - yield fisrt_reply - async for reply in generator: - yield reply.body - - -async def decode( - zmq_msg_request: ZmqMsgRequest -) -> Union[JSONResponse, StreamingResponse]: - logger.debug("start decode") - generator = execute_task_async(zmq_msg_request, app.state.decode_socket) - - async for res in generator: - logger.debug("res: %s", res) - if res.body_type == "response": - return JSONResponse(res.body) - else: - return StreamingResponse(generate_stream_response( - res.body, generator), - media_type=CONTENT_TYPE_STREAM) - - # If the generator is empty, return a default error response - logger.error("No response received from generator") - return JSONResponse({"error": "No response received from generator"}, - status_code=HTTPStatus.INTERNAL_SERVER_ERROR) - - -def cleanup_request_id(request_id: str): - if request_id in request_queues: - logger.info("del request_id: %s, decode finished", request_id) - del request_queues[request_id] - - -async def run_disagg_connector(args, **uvicorn_kwargs): - logger.info("vLLM Disaggregate Connector start %s %s", args, - uvicorn_kwargs) - logger.info(args.prefill_addr) - app.state.port = args.port - app.state.prefill_addr = f"ipc://{args.prefill_addr}" - app.state.decode_addr = f"ipc://{args.decode_addr}" - logger.info( - "start connect prefill_addr: %s decode_addr: %s " - "zmq server fastapi port: %s", app.state.prefill_addr, - app.state.decode_addr, app.state.port) - - def signal_handler(*_) -> None: - # Interrupt server on sigterm while initializing - raise KeyboardInterrupt("terminated") - - signal.signal(signal.SIGTERM, signal_handler) - # init uvicorn server - config = uvicorn.Config(app, host="0.0.0.0", port=app.state.port) - server = uvicorn.Server(config) - await server.serve() - - -if __name__ == "__main__": - # NOTE(simon): - # This section should be sync with vllm/entrypoints/cli/connect.py for CLI - # entrypoints. - parser = FlexibleArgumentParser(description="vLLM disagg connect server.") - parser.add_argument("--port", - type=int, - default=8001, - help="The fastapi server port default 8001") - # security concern only support ipc now - parser.add_argument("--prefill-addr", - type=str, - required=True, - help="The zmq ipc prefill address") - parser.add_argument("--decode-addr", - type=str, - required=True, - help="The zmq ipc decode address") - - args = parser.parse_args() - - uvloop.run(run_disagg_connector(args)) diff --git a/vllm/entrypoints/disaggregated/__init__.py b/vllm/entrypoints/disaggregated/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/entrypoints/disaggregated/connector.py b/vllm/entrypoints/disaggregated/connector.py new file mode 100644 index 0000000000000..ce2e3311eec19 --- /dev/null +++ b/vllm/entrypoints/disaggregated/connector.py @@ -0,0 +1,305 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import msgspec +from collections.abc import AsyncGenerator +from typing import Dict, Mapping, Optional + +import uvicorn +import uvloop +import zmq +import zmq.asyncio + +from vllm import SamplingParams +from vllm.config import DecodingConfig, ModelConfig +from vllm.core.scheduler import SchedulerOutputs +from vllm.entrypoints.openai.api_server import run_server +from vllm.entrypoints.openai.cli_args import (make_arg_parser, + validate_parsed_serve_args) +from vllm.inputs.data import PromptType +from vllm.inputs.preprocess import InputPreprocessor +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.outputs import PoolingRequestOutput, RequestOutput +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import BeamSearchParams, SamplingParams +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import Device, FlexibleArgumentParser, make_zmq_socket + +logger = init_logger(__name__) + +DEFAULT_MAX_TOKENS = 32000 + +# NOTE FOR DEVELOPERS: +# DO NOT USE PICKLE FOR THESE CLASSES. IN A MULTI NODE +# SETUP WE WILL USE TCP. WE CANNOT USE PICKLE OTHERWISE +# WE RISK REMOTE CODE EXECUTION FROM UNSTRUSTED USERS. + +class PDRequest(msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False): # type: ignore[call-arg] + request_id: str + prompt: str + sampling_params: SamplingParams + # TODO: support multimodal inputs. + +class PDResponse(msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False): # type: ignore[call-arg] + request_id: str + success: bool + delta_text: Optional[str] = None + finish_reason: Optional[str] = None + stop_reason: Optional[str] = None + logprobs = None # TODO + + +class PDEngine: + """ + PDEngine: + Equiavlent of AsyncLLM for P/D. Assumes there is + a Prefill and Decode service already running. + + * TODO: actually handle errors and failure. + * TODO: support more than just text input. + * TODO: move under vllm/v1/engine one past prototype. + """ + def __init__(self, prefill_addr: str, decode_addr: str, connector_addr: str): + # Request queues. + self.queues: Dict[str, asyncio.Queue] = {} + + # Serialization encoder. + self.encoder = msgspec.msgpack.Encoder() + + # ZMQ communication.. + self.ctx = zmq.asyncio.Context() + self.to_decode = make_zmq_socket( + self.ctx, f"{decode_addr}", zmq.constants.PUSH) + self.to_prefill = make_zmq_socket( + self.ctx, f"{prefill_addr}", zmq.constants.PUSH) + self.connector_addr = connector_addr + + # Background loops (started on first generate()). + self.output_handler: Optional[asyncio.Task] = None + self.log_running: Optional[asyncio.Task] = None + + def shutdown(self): + if (ctx := self.ctx) is not None: + ctx.destroy(linger=0) + if (task := self.log_running) is not None: + task.cancel() + if (task := self.output_handler) is not None: + task.cancel() + + async def _run_log_running(self): + logger.info("Running requests: %d", len(self.queues)) + await asyncio.sleep(10.) + + async def _run_output_handler(self, socket: zmq.asyncio.Socket): + """ + Pull responses from Decode + Prefill engines and + distribute back to the generate() tasks. + """ + decoder = msgspec.msgpack.Decoder(PDResponse) + + socket: Optional[zmq.asyncio.Socket] = None + try: + socket = make_zmq_socket( + self.ctx, self.connector_addr, zmq.constants.PULL) + + while True: + reponse_bytes = await socket.recv().buffer + response = decoder.decode(reponse_bytes) + self.queues[response.request_id].put_nowait(response) + except: + # TODO: actually handle failure and shutdown. + raise + finally: + if socket is not None: + socket.close(linger=0) + + async def _prefill(self, + request: PDRequest, + q: asyncio.Queue[PDResponse]) -> PDResponse: + # Send request to the prefill instance. + req_bytes = self.encoder(request) + await self.to_prefill.send(req_bytes, copy=False) + + # Wait for the prefill to be done. + response = await q.get() + assert response.request_id == request.request_id + if not response.success: + # TODO: actual error handling and shutdown. + raise Exception("Failed Prefill Request.") + + return response + + async def _decode(self, + request: PDRequest, + q: asyncio.Queue[PDResponse]) -> AsyncGenerator[PDResponse]: + # Send request to the decode instance. + req_bytes = self.encoder(request) + await self.to_decode.send(req_bytes, copy=False) + + # Iterate response queue and yield each response to caller.. + finished = False + while not finished: + response = await q.get() + if not response.success: + # TODO: actual error handling and shutdown. + raise Exception("Failed Decode Request.") + finished = response.finish_reason is not None + yield response + + async def generate( + self, + prompt: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> AsyncGenerator[PDResponse]: + # Start loops on first request. + if self.output_handler is None: + self.output_handler = asyncio.create_task(self._run_output_handler()) + self.log_running = asyncio.create_task(self._run_log_running()) + + # TODO: expand to suppo + if not isinstance(prompt, str): + raise ValueError("We currently only support text inputs!") + if request_id in self.queues: + raise ValueError(f"Found duplicate request_id: {request_id}!") + + # Queue to gather output from output_handler. + q: asyncio.Queue[PDResponse] = asyncio.Queue() + self.queues[request_id] = q + + # (1) Perform the prefill (max_tokens=1). + original_max_tokens = sampling_params.max_tokens + request = PDRequest(request_id, prompt, sampling_params) + request.sampling_params.max_tokens = 1 + response = await self._prefill(request, q) + yield response + + # (2) Perform the decodes (original tokens). + request.sampling_params.max_tokens = original_max_tokens + async for response in self._decode(request, q): + yield response + + async def beam_search( + self, + prompt: PromptType, + request_id: str, + params: BeamSearchParams, + ) -> AsyncGenerator[RequestOutput, None]: + raise NotImplementedError + + def encode( + self, + prompt: PromptType, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, + ) -> AsyncGenerator[PoolingRequestOutput, None]: + raise NotImplementedError + + async def abort(self, request_id: str) -> None: + raise NotImplementedError + + async def get_model_config(self) -> ModelConfig: + raise NotImplementedError + + async def get_decoding_config(self) -> DecodingConfig: + raise NotImplementedError + + async def get_input_preprocessor(self) -> InputPreprocessor: + raise NotImplementedError + + async def get_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + raise NotImplementedError + + async def is_tracing_enabled(self) -> bool: + return False + + async def do_log_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None, + ) -> None: + pass + + async def check_health(self) -> None: + pass + + async def start_profile(self) -> None: + raise NotImplementedError + + async def stop_profile(self) -> None: + raise NotImplementedError + + async def reset_prefix_cache(self, + device: Optional[Device] = None) -> None: + raise NotImplementedError + + async def sleep(self, level: int = 1) -> None: + raise NotImplementedError + + async def wake_up(self) -> None: + raise NotImplementedError + + async def is_sleeping(self) -> bool: + False + + async def add_lora(self, lora_request: LoRARequest) -> None: + raise NotImplementedError + + +async def run_disagg_connector(args, **uvicorn_kwargs): + logger.info("vLLM Connector Start: %s %s", args, uvicorn_kwargs) + + # NOTE FOR DEVELOPERS: when we shift this to TCP, we must + # ensure that the serialization is not pickle based to + # avoid RCE issues from untrusted users!!! + app.state.port = args.port + app.state.connector_addr = f"ipc://{args.connector_addr}" + app.state.decode_addr = f"ipc://{args.decode_addr}" + app.state.prefill_addr = f"ipc://{args.prefill_addr}" + + # init uvicorn server + config = uvicorn.Config(app, host=args.post, port=args.port) + server = uvicorn.Server(config) + await server.serve() + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description="vLLM OpenAI-Compatible RESTful API server.") + parser.add_argument("--connector-addr", + type=str, + required=True, + help="The zmq ipc connector address") + parser.add_argument("--prefill-addr", + type=str, + required=True, + help="The zmq ipc prefill address") + parser.add_argument("--decode-addr", + type=str, + required=True, + help="The zmq ipc decode address") + parser = make_arg_parser(parser) + + args = parser.parse_args() + validate_parsed_serve_args(args) + + uvloop.run(run_server(args))