From 2fec6e0b5c8aa528dad7a5c0c456b4f3323fd2b4 Mon Sep 17 00:00:00 2001
From: Robert Shaw
Date: Sat, 22 Mar 2025 18:45:00 -0400
Subject: [PATCH] Add disaggregated P/D connector, engine, and worker
 entrypoints

Signed-off-by: Robert Shaw
---
 vllm/entrypoints/disaggregated/connector.py | 131 ++++++++
 vllm/entrypoints/disaggregated/engine.py    | 335 ++++++++++++++++++++
 vllm/entrypoints/disaggregated/types.py     |  31 ++
 vllm/entrypoints/disaggregated/worker.py    | 136 ++++++++
 4 files changed, 633 insertions(+)
 create mode 100644 vllm/entrypoints/disaggregated/connector.py
 create mode 100644 vllm/entrypoints/disaggregated/engine.py
 create mode 100644 vllm/entrypoints/disaggregated/types.py
 create mode 100644 vllm/entrypoints/disaggregated/worker.py

diff --git a/vllm/entrypoints/disaggregated/connector.py b/vllm/entrypoints/disaggregated/connector.py
new file mode 100644
index 0000000000000..ae1bb21a08b41
--- /dev/null
+++ b/vllm/entrypoints/disaggregated/connector.py
@@ -0,0 +1,131 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import uvicorn
+import uvloop
+
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
+from fastapi import FastAPI, Request
+from fastapi.responses import JSONResponse, StreamingResponse
+
+from vllm.entrypoints.disaggregated.engine import PDEngine
+from vllm.entrypoints.openai.protocol import (CompletionRequest,
+                                              CompletionResponse,
+                                              ErrorResponse)
+from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
+from vllm.entrypoints.openai.serving_models import (BaseModelPath,
+                                                    OpenAIServingModels)
+from vllm.logger import init_logger
+from vllm.utils import FlexibleArgumentParser, set_ulimit
+
+# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
+logger = init_logger('vllm.entrypoints.disaggregated.api_server')
+
+app = FastAPI()
+
+
+@app.get("/v1/models")
+async def show_available_models(raw_request: Request):
+    handler: OpenAIServingModels = raw_request.app.state.openai_serving_models
+    models_ = await handler.show_available_models()
+    return JSONResponse(content=models_.model_dump())
+
+
+@app.post("/v1/completions")
+async def create_completion(request: CompletionRequest, raw_request: Request):
+    handler: OpenAIServingCompletion = (
+        raw_request.app.state.openai_serving_completion)
+    generator = await handler.create_completion(request, raw_request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    elif isinstance(generator, CompletionResponse):
+        return JSONResponse(content=generator.model_dump())
+
+    return StreamingResponse(content=generator,
+                             media_type="text/event-stream")
+
+
+@asynccontextmanager
+async def pd_engine_client_ctx_manager(
+    prefill_addr: str,
+    decode_addr: str,
+    connector_addr: str,
+    model_name: str,
+) -> AsyncIterator[PDEngine]:
+    engine = PDEngine(prefill_addr, decode_addr, connector_addr, model_name)
+    try:
+        yield engine
+    finally:
+        engine.shutdown()
+
+
+async def main(args, **uvicorn_kwargs):
+    logger.info("vLLM Disaggregated Connector Start %s %s", args,
+                uvicorn_kwargs)
+
+    # Avoid dropping requests under high concurrency.
+    set_ulimit()
+
+    # IPC Paths.
+    # NOTE FOR DEVELOPERS: when shifting to TCP, ensure you
+    # are not using pickle to avoid RCE security flaws.
+    prefill_addr = f"ipc://{args.prefill_addr}"
+    decode_addr = f"ipc://{args.decode_addr}"
+    connector_addr = f"ipc://{args.connector_addr}"
+
+    # Start Engine.
+    async with pd_engine_client_ctx_manager(
+            prefill_addr=prefill_addr,
+            decode_addr=decode_addr,
+            connector_addr=connector_addr,
+            model_name=args.model) as engine_client:
+
+        # Initialize App State.
+        model_config = await engine_client.get_model_config()
+        app.state.openai_serving_models = OpenAIServingModels(
+            engine_client=engine_client,
+            model_config=model_config,
+            base_model_paths=[
+                BaseModelPath(name=args.served_model_name or args.model,
+                              model_path=args.model)
+            ],
+        )
+        app.state.openai_serving_completion = OpenAIServingCompletion(
+            engine_client=engine_client,
+            model_config=model_config,
+            models=app.state.openai_serving_models,
+            request_logger=None,
+        )
+
+        # Run Server.
+        config = uvicorn.Config(app, host=args.host, port=args.port)
+        server = uvicorn.Server(config)
+        await server.serve()
+
+
+if __name__ == "__main__":
+    parser = FlexibleArgumentParser(
+        description="vLLM OpenAI-Compatible P/D Server.")
+    parser.add_argument("--host",
+                        type=str,
+                        default="0.0.0.0",
+                        help="The host of the HTTP server.")
+    parser.add_argument("--port",
+                        type=int,
+                        default=8001,
+                        help="The port of the HTTP server.")
+    parser.add_argument("--model",
+                        type=str,
+                        required=True,
+                        help="The path to the model.")
+    parser.add_argument("--served-model-name",
+                        type=str,
+                        default=None,
+                        help="The served name of the model.")
+    parser.add_argument("--connector-addr",
+                        type=str,
+                        required=True,
+                        help="The zmq ipc connector address.")
+    parser.add_argument("--prefill-addr",
+                        type=str,
+                        required=True,
+                        help="The zmq ipc prefill address.")
+    parser.add_argument("--decode-addr",
+                        type=str,
+                        required=True,
+                        help="The zmq ipc decode address.")
+    args = parser.parse_args()
+    uvloop.run(main(args))
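The connector speaks the standard OpenAI completions surface, so any OpenAI-compatible client can exercise it end to end. A minimal smoke-test sketch, assuming the connector is already running on the default port 8001 and that `my-model` stands in for whatever was passed to `--model`/`--served-model-name`:

```python
# Smoke-test client for the connector (host, port, and model are placeholders).
import requests

resp = requests.post(
    "http://localhost:8001/v1/completions",
    json={
        "model": "my-model",            # must match the served model name
        "prompt": "Hello, my name is",
        "max_tokens": 16,
        "stream": False,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])
```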
diff --git a/vllm/entrypoints/disaggregated/engine.py b/vllm/entrypoints/disaggregated/engine.py
new file mode 100644
index 0000000000000..bef3517488509
--- /dev/null
+++ b/vllm/entrypoints/disaggregated/engine.py
@@ -0,0 +1,335 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import asyncio
+import os
+from collections.abc import AsyncGenerator
+from typing import Dict, List, Mapping, Optional
+
+import msgspec
+import zmq
+import zmq.asyncio
+
+from vllm.config import DecodingConfig, ModelConfig
+from vllm.core.scheduler import SchedulerOutputs
+from vllm.entrypoints.disaggregated.types import PDRequest, PDResponse
+from vllm.inputs.data import PromptType
+from vllm.inputs.preprocess import InputPreprocessor
+from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.outputs import (CompletionOutput, PoolingRequestOutput,
+                          RequestOutput)
+from vllm.pooling_params import PoolingParams
+from vllm.prompt_adapter.request import PromptAdapterRequest
+from vllm.sampling_params import BeamSearchParams, SamplingParams
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.transformers_utils.tokenizer_group import TokenizerGroup
+from vllm.utils import Device
+
+logger = init_logger(__name__)
+
+DEFAULT_MAX_TOKENS = 32000
+
+
+class PDEngine:
+    """
+    PDEngine:
+    Equivalent of AsyncLLM for P/D. Assumes there is
+    a Prefill and Decode service already running.
+
+    * TODO: actually handle errors and failure.
+    * TODO: support more than just text input.
+    * TODO: move under vllm/v1/engine once past prototype.
+    """
+
+    def __init__(self, prefill_addr: str, decode_addr: str,
+                 connector_addr: str, model_name: str):
+        # Request queues.
+        self.queues: Dict[str, asyncio.Queue] = {}
+
+        # Serialization encoder.
+        self.encoder = msgspec.msgpack.Encoder()
+
+        # ZMQ communication.
+        self.ctx = zmq.asyncio.Context()
+        self.to_decode = self.ctx.socket(zmq.constants.PUSH)
+        self.to_decode.bind(decode_addr)
+        self.to_prefill = self.ctx.socket(zmq.constants.PUSH)
+        self.to_prefill.bind(prefill_addr)
+        self.connector_addr = connector_addr
+        self.decode_addr = decode_addr
+        self.prefill_addr = prefill_addr
+
+        # Background loops (started on first generate()).
+        self.output_handler: Optional[asyncio.Task] = None
+        self.log_running: Optional[asyncio.Task] = None
+
+        # Dummy: needed for EngineClient Protocol.
+        # TODO: refactor EngineClient to avoid needing this.
+        self.model_config = ModelConfig(model=model_name,
+                                        tokenizer=model_name,
+                                        tokenizer_mode="auto",
+                                        trust_remote_code=False,
+                                        dtype="auto",
+                                        task="generate",
+                                        seed=42)
+
+        # Dummy: needed for EngineClient Protocol.
+        # TODO: refactor EngineClient to avoid needing this.
+        init_kwargs = dict(
+            tokenizer_id=self.model_config.tokenizer,
+            enable_lora=False,
+            max_num_seqs=1024,
+            max_loras=0,
+            max_input_length=None,
+            tokenizer_mode=self.model_config.tokenizer_mode,
+            trust_remote_code=self.model_config.trust_remote_code,
+            revision=self.model_config.tokenizer_revision,
+            truncation_side=self.model_config.truncation_side)
+        self.tokenizer = TokenizerGroup(**init_kwargs)
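The socket directions above are easy to misread: the connector (this class) binds a PUSH socket per worker for outbound requests and later binds a PULL socket on `connector_addr` for inbound responses, while each worker connects the matching ends. A standalone toy of that topology, with placeholder ipc paths rather than the patch's CLI-derived ones:

```python
# Toy PUSH/PULL topology: the connector binds, the worker connects.
import zmq

ctx = zmq.Context()

# Connector side.
to_prefill = ctx.socket(zmq.PUSH)      # requests out
to_prefill.bind("ipc:///tmp/prefill.sock")
from_workers = ctx.socket(zmq.PULL)    # responses in
from_workers.bind("ipc:///tmp/connector.sock")

# Worker side (a separate process in the real deployment).
from_connector = ctx.socket(zmq.PULL)
from_connector.connect("ipc:///tmp/prefill.sock")
to_connector = ctx.socket(zmq.PUSH)
to_connector.connect("ipc:///tmp/connector.sock")

to_prefill.send(b"request")
assert from_connector.recv() == b"request"
to_connector.send(b"response")
assert from_workers.recv() == b"response"
ctx.destroy(linger=0)
```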
+
+    def shutdown(self):
+        if (ctx := self.ctx) is not None:
+            ctx.destroy(linger=0)
+        if (task := self.log_running) is not None:
+            task.cancel()
+        if (task := self.output_handler) is not None:
+            task.cancel()
+
+        ipc_paths = [self.connector_addr, self.decode_addr, self.prefill_addr]
+        for path in ipc_paths:
+            socket_path = path.replace("ipc://", "")
+            if os.path.exists(socket_path):
+                os.remove(socket_path)
+
+    async def _run_log_running(self):
+        while True:
+            logger.info("Running requests: %d", len(self.queues))
+            await asyncio.sleep(10.)
+
+    async def _run_output_handler(self):
+        """
+        Pull responses from Decode + Prefill engines and
+        distribute back to the generate() tasks.
+        """
+        decoder = msgspec.msgpack.Decoder(PDResponse)
+
+        socket: Optional[zmq.asyncio.Socket] = None
+        try:
+            socket = self.ctx.socket(zmq.constants.PULL)
+            socket.bind(self.connector_addr)
+
+            while True:
+                response_bytes = await socket.recv()
+                response = decoder.decode(response_bytes)
+                logger.debug("Got Response: %s", response.request_id)
+                self.queues[response.request_id].put_nowait(response)
+        finally:
+            # TODO: actually handle failure and shutdown.
+            if socket is not None:
+                socket.close(linger=0)
+
+    async def _prefill(
+        self,
+        request: PDRequest,
+        q: asyncio.Queue[PDResponse],
+    ) -> None:
+        # Send request to the prefill instance.
+        req_bytes = self.encoder.encode(request)
+        await self.to_prefill.send(req_bytes, copy=False)
+
+        # Wait for the prefill to be done.
+        response = await q.get()
+        assert response.request_id == request.request_id
+        if not response.success:
+            # TODO: actual error handling and shutdown.
+            raise Exception("Failed Prefill Request.")
+
+    async def _decode(
+        self,
+        request: PDRequest,
+        q: asyncio.Queue[PDResponse],
+    ) -> AsyncGenerator[PDResponse, None]:
+        # Send request to the decode instance.
+        req_bytes = self.encoder.encode(request)
+        await self.to_decode.send(req_bytes, copy=False)
+
+        # Iterate the response queue and yield each response to the caller.
+        finished = False
+        while not finished:
+            response = await q.get()
+            logger.debug("Got Decode Response: %s", response.request_id)
+            if not response.success:
+                # TODO: actual error handling and shutdown.
+                raise Exception("Failed Decode Request.")
+            finished = response.finish_reason is not None
+            yield response
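`_run_output_handler` is a demultiplexer: one PULL socket receives responses for every in-flight request, and each response is routed to a per-request `asyncio.Queue` keyed by `request_id`, which `_prefill`/`_decode` then await. The same pattern in isolation, with an in-memory queue standing in for the socket:

```python
# Fan-out pattern: one inbox, one queue per request_id.
import asyncio


async def demo() -> None:
    queues: dict[str, asyncio.Queue] = {}
    inbox: asyncio.Queue = asyncio.Queue()  # stands in for the PULL socket

    async def output_handler() -> None:
        while True:
            request_id, payload = await inbox.get()
            queues[request_id].put_nowait(payload)

    handler = asyncio.create_task(output_handler())
    inbox.put_nowait(("req-0", "first token"))

    queues["req-0"] = asyncio.Queue()
    print(await queues["req-0"].get())  # -> first token
    handler.cancel()


asyncio.run(demo())
```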
+
+    def _to_request_output(
+        self,
+        pd_response: PDResponse,
+        prompt_token_ids: List[int],
+    ) -> RequestOutput:
+        finished = pd_response.finish_reason is not None
+        return RequestOutput(
+            request_id=pd_response.request_id,
+            prompt=None,
+            prompt_token_ids=prompt_token_ids,
+            prompt_logprobs=None,
+            outputs=[
+                CompletionOutput(index=0,
+                                 text=pd_response.text,
+                                 token_ids=pd_response.token_ids,
+                                 cumulative_logprob=None,
+                                 logprobs=None,
+                                 finish_reason=pd_response.finish_reason,
+                                 stop_reason=pd_response.stop_reason)
+            ],
+            finished=finished,
+        )
+
+    async def generate(
+        self,
+        prompt: PromptType,
+        sampling_params: SamplingParams,
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
+    ) -> AsyncGenerator[RequestOutput, None]:
+        # Start background loops on first request.
+        if self.output_handler is None:
+            self.output_handler = asyncio.create_task(
+                self._run_output_handler())
+            self.log_running = asyncio.create_task(self._run_log_running())
+
+        # TODO: Expand to support the full matrix.
+        if "prompt_token_ids" not in prompt:
+            raise NotImplementedError(
+                "We currently only support TokensPrompt for P/D!")
+        if lora_request is not None:
+            raise NotImplementedError(
+                "We currently do not support LoRA for P/D!")
+        if trace_headers is not None:
+            raise NotImplementedError(
+                "We currently do not support tracing for P/D!")
+        if prompt_adapter_request is not None:
+            raise NotImplementedError(
+                "We currently do not support prompt adapter for P/D!")
+        if priority != 0:
+            raise NotImplementedError(
+                "We currently do not support priority for P/D!")
+        if request_id in self.queues:
+            raise ValueError(f"Found duplicate request_id: {request_id}!")
+
+        # Queue to gather output from output_handler.
+        q: asyncio.Queue[PDResponse] = asyncio.Queue()
+        self.queues[request_id] = q
+
+        try:
+            # (1) Perform the Prefill: force max_tokens=1 so the prefill
+            # instance only computes the prompt (KV cache + first token).
+            original_max_tokens = sampling_params.max_tokens
+            prompt_token_ids = prompt["prompt_token_ids"]
+            request = PDRequest(request_id=request_id,
+                                prompt_token_ids=prompt_token_ids,
+                                sampling_params=sampling_params)
+            request.sampling_params.max_tokens = 1
+            logger.debug("Sending Prefill: %s", request.request_id)
+            await self._prefill(request, q)
+
+            # (2) Perform the Decodes with the original token budget.
+            logger.debug("Sending Decode: %s", request.request_id)
+            request.sampling_params.max_tokens = original_max_tokens
+            async for pd_response in self._decode(request, q):
+                logger.debug("Got Decode: %s", request.request_id)
+                yield self._to_request_output(pd_response, prompt_token_ids)
+        finally:
+            # Free the per-request queue once the request completes.
+            self.queues.pop(request_id, None)
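`generate()` is the whole P/D dance in one place: register a queue, send the request to prefill with `max_tokens` forced to 1, wait for that single response, then resend to decode with the original budget and stream the rest. A hypothetical direct caller (in practice `OpenAIServingCompletion` drives this), with placeholder ipc paths, model name, and token ids:

```python
# Hypothetical direct use of PDEngine.generate; requires both workers running.
import asyncio
import uuid

from vllm import SamplingParams
from vllm.entrypoints.disaggregated.engine import PDEngine


async def demo() -> None:
    engine = PDEngine(prefill_addr="ipc:///tmp/prefill.sock",
                      decode_addr="ipc:///tmp/decode.sock",
                      connector_addr="ipc:///tmp/connector.sock",
                      model_name="my-model")  # placeholder model
    try:
        async for output in engine.generate(
                prompt={"prompt_token_ids": [1, 2, 3]},
                sampling_params=SamplingParams(max_tokens=16),
                request_id=str(uuid.uuid4())):
            print(output.outputs[0].text)
    finally:
        engine.shutdown()


asyncio.run(demo())
```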
+ logger.debug("Sending Decode: %s", request.request_id) + request.sampling_params.max_tokens = original_max_tokens + async for pd_response in self._decode(request, q): + logger.debug("Got Decode: %s", request.request_id) + yield self._to_request_output(pd_response, prompt_token_ids) + + async def beam_search( + self, + prompt: PromptType, + request_id: str, + params: BeamSearchParams, + ) -> AsyncGenerator[RequestOutput, None]: + raise NotImplementedError + + def encode( + self, + prompt: PromptType, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, + ) -> AsyncGenerator[PoolingRequestOutput, None]: + raise NotImplementedError + + async def abort(self, request_id: str) -> None: + raise NotImplementedError + + async def get_model_config(self) -> ModelConfig: + return self.model_config + + async def get_decoding_config(self) -> DecodingConfig: + raise NotImplementedError + + async def get_input_preprocessor(self) -> InputPreprocessor: + raise NotImplementedError + + async def get_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + if lora_request is not None: + raise NotImplementedError( + "LoRA is not yet supported in the PDEngine.") + return self.tokenizer.get_lora_tokenizer(None) + + async def is_tracing_enabled(self) -> bool: + return False + + async def do_log_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None, + ) -> None: + pass + + async def check_health(self) -> None: + pass + + async def start_profile(self) -> None: + raise NotImplementedError + + async def stop_profile(self) -> None: + raise NotImplementedError + + async def reset_prefix_cache(self, + device: Optional[Device] = None) -> None: + raise NotImplementedError + + async def sleep(self, level: int = 1) -> None: + raise NotImplementedError + + async def wake_up(self) -> None: + raise NotImplementedError + + async def is_sleeping(self) -> bool: + False + + async def add_lora(self, lora_request: LoRARequest) -> None: + raise NotImplementedError + + @property + def errored(self) -> bool: + return False \ No newline at end of file diff --git a/vllm/entrypoints/disaggregated/types.py b/vllm/entrypoints/disaggregated/types.py new file mode 100644 index 0000000000000..79263ba88893d --- /dev/null +++ b/vllm/entrypoints/disaggregated/types.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: Apache-2.0 + +import msgspec +from typing import List, Optional +from vllm import SamplingParams + +# NOTE FOR DEVELOPERS: +# DO NOT USE PICKLE FOR THESE CLASSES. IN A MULTI NODE +# SETUP WE WILL USE TCP. WE CANNOT USE PICKLE OTHERWISE +# WE RISK REMOTE CODE EXECUTION FROM UNSTRUSTED USERS. + +class PDRequest(msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False): # type: ignore[call-arg] + request_id: str + prompt_token_ids: List[int] + sampling_params: SamplingParams + # TODO: support multimodal inputs. 
diff --git a/vllm/entrypoints/disaggregated/worker.py b/vllm/entrypoints/disaggregated/worker.py
new file mode 100644
index 0000000000000..80b1d614bbe21
--- /dev/null
+++ b/vllm/entrypoints/disaggregated/worker.py
@@ -0,0 +1,136 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import asyncio
+import signal
+from typing import Optional
+
+import msgspec
+import uvloop
+import zmq
+import zmq.asyncio
+
+from vllm.engine.async_llm_engine import AsyncEngineArgs
+from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.disaggregated.types import PDRequest, PDResponse
+from vllm.entrypoints.openai.api_server import build_async_engine_client
+from vllm.inputs.data import TokensPrompt
+from vllm.logger import init_logger
+from vllm.utils import FlexibleArgumentParser, set_ulimit
+from vllm.version import __version__ as VLLM_VERSION
+
+# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
+logger = init_logger('vllm.entrypoints.disaggregated.worker')
+
+
+async def handle_request(
+    request: PDRequest,
+    engine: EngineClient,
+    socket: zmq.asyncio.Socket,
+    encoder: msgspec.msgpack.Encoder,
+) -> None:
+    request_id = request.request_id
+    try:
+        # 1) Generate RequestOutputs.
+        prompt: TokensPrompt = {
+            "prompt_token_ids": request.prompt_token_ids
+        }
+        async for request_output in engine.generate(
+                prompt=prompt,
+                sampling_params=request.sampling_params,
+                request_id=request_id):
+
+            assert len(request_output.outputs) == 1, (
+                "Only support N=1 right now.")
+            out = request_output.outputs[0]
+
+            # 2) Convert RequestOutput --> PDResponse.
+            response = PDResponse(
+                request_id=request_id,
+                success=True,
+                text=out.text,
+                token_ids=out.token_ids,
+                finish_reason=out.finish_reason,
+                stop_reason=out.stop_reason,
+            )
+            response_bytes = encoder.encode(response)
+
+            # 3) Send to Connector.
+            logger.debug("Sending: %s", request_id)
+            await socket.send(response_bytes, copy=False)
+            logger.debug("Sent: %s", request_id)
+
+    except Exception as e:
+        # TODO: actual error handling.
+        logger.error("Exception in Worker Routine: %s request_id: %s", e,
+                     request_id)
+        response = PDResponse(request_id=request_id, success=False)
+        response_bytes = encoder.encode(response)
+        await socket.send(response_bytes, copy=False)
+
+
+async def run_server(args, engine: EngineClient):
+    """Get Requests and Handle Them."""
+    logger.info("P/D Worker is Ready To Receive Requests.")
+
+    running_requests: set[asyncio.Task] = set()
+    decoder = msgspec.msgpack.Decoder(PDRequest)
+    encoder = msgspec.msgpack.Encoder()
+
+    ctx: Optional[zmq.asyncio.Context] = None
+    try:
+        # IPC Setup.
+        ctx = zmq.asyncio.Context()
+        from_connector = ctx.socket(zmq.constants.PULL)
+        from_connector.connect(f"ipc://{args.worker_addr}")
+        to_connector = ctx.socket(zmq.constants.PUSH)
+        to_connector.connect(f"ipc://{args.connector_addr}")
+
+        # Main Loop.
+        while True:
+            # 1) Get request from the Connector.
+            pd_request_bytes = await from_connector.recv()
+            pd_request = decoder.decode(pd_request_bytes)
+
+            # 2) Launch a coroutine to handle the request, keeping a
+            #    strong reference so the task is not garbage collected.
+            task = asyncio.create_task(
+                handle_request(pd_request, engine, to_connector, encoder))
+            running_requests.add(task)
+            task.add_done_callback(running_requests.discard)
+
+    except KeyboardInterrupt:
+        logger.debug("Worker server loop interrupted.")
+
+    finally:
+        for task in running_requests:
+            task.cancel()
+        if ctx is not None:
+            ctx.destroy(linger=0)
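The `running_requests` set in `run_server` is not optional bookkeeping: asyncio holds only weak references to tasks, so without a strong reference an in-flight `handle_request` task can be garbage collected mid-request. The `add(task)` plus `add_done_callback(discard)` pair is the standard idiom; in isolation:

```python
# Keep-a-strong-reference idiom for fire-and-forget asyncio tasks.
import asyncio


async def demo() -> None:
    background_tasks: set[asyncio.Task] = set()

    async def job(n: int) -> None:
        await asyncio.sleep(0)
        print(f"job {n} done")

    for n in range(3):
        task = asyncio.create_task(job(n))
        background_tasks.add(task)                         # strong reference
        task.add_done_callback(background_tasks.discard)   # drop when finished

    await asyncio.gather(*background_tasks)


asyncio.run(demo())
```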
+
+
+async def main(args) -> None:
+    logger.info("vLLM P/D Worker Server %s", VLLM_VERSION)
+    logger.info("args: %s", args)
+
+    # Workaround to avoid footguns where uvicorn drops requests
+    # with too many concurrent requests active due to ulimit.
+    set_ulimit()
+
+    # Interrupt on sigterm during initialization.
+    def signal_handler(*_) -> None:
+        raise KeyboardInterrupt("terminated")
+
+    signal.signal(signal.SIGTERM, signal_handler)
+
+    args.disable_frontend_multiprocessing = False
+    async with build_async_engine_client(args) as engine:
+        await run_server(args, engine)
+
+
+if __name__ == "__main__":
+    parser = FlexibleArgumentParser()
+    parser.add_argument('--connector-addr',
+                        type=str,
+                        required=True,
+                        help='The address of the connector.')
+    parser.add_argument('--worker-addr',
+                        type=str,
+                        required=True,
+                        help='The address of the worker.')
+    AsyncEngineArgs.add_cli_args(parser)
+    args = parser.parse_args()
+    uvloop.run(main(args))
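A full deployment is three processes wired by matching ipc paths: both workers and the connector share `--connector-addr`, the prefill worker's `--worker-addr` must equal the connector's `--prefill-addr`, and likewise for decode. A hedged launch sketch with placeholder paths and model; it assumes the files are importable at these module paths once the patch lands, and a real P/D setup would additionally need whatever KV-transfer configuration the workers require:

```python
# Hypothetical three-process launch (paths, model, and flags are placeholders).
import subprocess

MODEL = "/path/to/model"  # placeholder

procs = [
    subprocess.Popen([
        "python", "-m", "vllm.entrypoints.disaggregated.worker",
        "--worker-addr", "/tmp/prefill.sock",
        "--connector-addr", "/tmp/connector.sock",
        "--model", MODEL,
    ]),
    subprocess.Popen([
        "python", "-m", "vllm.entrypoints.disaggregated.worker",
        "--worker-addr", "/tmp/decode.sock",
        "--connector-addr", "/tmp/connector.sock",
        "--model", MODEL,
    ]),
    subprocess.Popen([
        "python", "-m", "vllm.entrypoints.disaggregated.connector",
        "--prefill-addr", "/tmp/prefill.sock",
        "--decode-addr", "/tmp/decode.sock",
        "--connector-addr", "/tmp/connector.sock",
        "--model", MODEL,
        "--port", "8001",
    ]),
]

for p in procs:
    p.wait()
```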