From 66349c33a1aaec0d081ecec8d799def32794b3bc Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 23 Mar 2025 22:36:57 +0000 Subject: [PATCH] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/disaggregated/pd_client.py | 131 ++++++++++++-------- vllm/disaggregated/pd_worker.py | 27 ++-- vllm/disaggregated/protocol.py | 6 - vllm/entrypoints/disaggregated/connector.py | 36 +++--- 4 files changed, 119 insertions(+), 81 deletions(-) diff --git a/vllm/disaggregated/pd_client.py b/vllm/disaggregated/pd_client.py index 9d00a29030133..98657d0e21432 100644 --- a/vllm/disaggregated/pd_client.py +++ b/vllm/disaggregated/pd_client.py @@ -3,7 +3,7 @@ import asyncio import os from collections.abc import AsyncGenerator, Mapping -from typing import Optional +from typing import Optional, Union import msgspec import zmq @@ -11,7 +11,10 @@ import zmq.asyncio from vllm.config import DecodingConfig, ModelConfig from vllm.core.scheduler import SchedulerOutputs -from vllm.entrypoints.disaggregated.types import PDRequest, PDResponse +from vllm.disaggregated.protocol import (PDGenerationRequest, + PDGenerationResponse, PDRequestType, + PDResponseType) +from vllm.engine.protocol import EngineClient from vllm.inputs.data import PromptType from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger @@ -30,14 +33,36 @@ logger = init_logger(__name__) DEFAULT_MAX_TOKENS = 32000 -class PDClient: +class PDController(EngineClient): """ - PDEngine: - Equiavlent of AsyncLLM for P/D. Assumes there is - a Prefill and Decode service already running. + Controller that schedules work on the PDWorkers. - * TODO: actually handle errors and failure. - * TODO: support more than just text input. + Conforms for the EngineClient protocol so it can + be wrapped with the OpenAI Server. + + Two Phases: + * Send request to prefill worker, await ack. + * Send request to decode worker. + + KVSync happens directly between Engines, + handled by vLLM KVCacheTransfer. + + [ OpenAI Server ] + | + [ PDController ] + | + [ zmq ] + | + [ PDWorker ] [ PDWorker ] + | + [ Engine ] <---> [ Engine ] + + After PR #12957, we will support xPyD, so we will + also need to implement a scheduler. + we will need to support multiple + + * TODO: actually handle errors and failure. + * TODO: support the full API (logprobs, multimodal). """ def __init__(self, prefill_addr: str, decode_addr: str, @@ -49,6 +74,8 @@ class PDClient: self.encoder = msgspec.msgpack.Encoder() # ZMQ communication. + # TODO: once https://github.com/vllm-project/vllm/pull/12957 + # lands, do service discovery to scale out workers. self.ctx = zmq.asyncio.Context() self.to_decode = self.ctx.socket(zmq.constants.PUSH) self.to_decode.bind(f"{decode_addr}") @@ -63,7 +90,7 @@ class PDClient: self.log_running: Optional[asyncio.Task] = None # Dummy: needed for EngineClient Protocol. - # TODO: refactor EngineClient to avoid needing this. + # TODO: refactor OAI Server to avoid needing this. self.model_config = ModelConfig(model=model_name, tokenizer=model_name, tokenizer_mode="auto", @@ -73,7 +100,7 @@ class PDClient: seed=42) # Dummy: needed for EngineClient Protocol. - # TODO: refactor EngineClient to avoid needing this. + # TODO: refactor OAI Server to avoid needing this. init_kwargs = dict( tokenizer_id=self.model_config.tokenizer, enable_lora=False, @@ -109,7 +136,7 @@ class PDClient: Pull responses from Decode + Prefill engines and distribute back to the generate() tasks. """ - decoder = msgspec.msgpack.Decoder(PDResponse) + decoder = msgspec.msgpack.Decoder(PDGenerationResponse) socket: Optional[zmq.asyncio.Socket] = None try: @@ -117,72 +144,78 @@ class PDClient: socket.bind(self.connector_addr) while True: - reponse_bytes = await socket.recv() - response = decoder.decode(reponse_bytes) - logger.debug("Got Response: %s", response.request_id) - self.queues[response.request_id].put_nowait(response) - except: - # TODO: actually handle failure and shutdown. - raise + res_type, res_data = await socket.recv_multipart() + if res_type == PDResponseType.FAILURE: + raise Exception("Failure Response from PDWorker.") + elif res_type == PDResponseType.GENERATION: + response = decoder.decode(res_data) + logger.debug("Got Response: %s", response.request_id) + self.queues[response.request_id].put_nowait(response) + else: + raise Exception("Unknown response type.") + except Exception as e: + # TODO: distinguish between fatal and non-fatal errors. + for _, q in self.queues.values(): + q.put_nowait(e) + raise e finally: if socket is not None: socket.close(linger=0) - async def _prefill( + async def _run_prefill( self, - request: PDRequest, - q: asyncio.Queue[PDResponse], + request: PDGenerationRequest, + q: asyncio.Queue[Union[Exception, PDGenerationResponse]], ): # Send request to the prefill instance. req_bytes = self.encoder.encode(request) - await self.to_prefill.send(req_bytes, copy=False) + msg = (PDRequestType.GENERATION, req_bytes) + await self.to_prefill.send_multipart(msg, copy=False) # Wait for the prefill to be done. response = await q.get() - assert response.request_id == request.request_id - if not response.success: - # TODO: actual error handling and shutdown. - raise Exception("Failed Prefill Request.") + if isinstance(response, Exception): + raise response - async def _decode( + async def _run_decode( self, - request: PDRequest, - q: asyncio.Queue[PDResponse], - ) -> AsyncGenerator[PDRequest]: + request: PDGenerationRequest, + q: asyncio.Queue[Union[Exception, PDGenerationResponse]], + ) -> AsyncGenerator[PDGenerationResponse]: # Send request to the decode instance. req_bytes = self.encoder.encode(request) - await self.to_decode.send(req_bytes, copy=False) + msg = (PDRequestType.GENERATION, req_bytes) + await self.to_decode.send_multipart(msg, copy=False) - # Iterate response queue and yield each response to caller.. + # Iterate response queue and yield each response to caller. finished = False while not finished: response = await q.get() - if not response.success: - # TODO: actual error handling and shutdown. - raise Exception("Failed Decode Request.") + if isinstance(response, Exception): + raise response finished = response.finish_reason is not None yield response def _to_request_output( self, - pd_response: PDResponse, + response: PDGenerationResponse, prompt_token_ids: list[int], ) -> RequestOutput: - finished = pd_response.finish_reason is not None + finished = response.finish_reason is not None return RequestOutput( - request_id=pd_response.request_id, + request_id=response.request_id, prompt=None, prompt_token_ids=prompt_token_ids, prompt_logprobs=None, outputs=[ CompletionOutput(index=0, - text=pd_response.text, - token_ids=pd_response.token_ids, + text=response.text, + token_ids=response.token_ids, cumulative_logprob=None, logprobs=None, - finish_reason=pd_response.finish_reason, - stop_reason=pd_response.stop_reason) + finish_reason=response.finish_reason, + stop_reason=response.stop_reason) ], finished=finished, ) @@ -223,24 +256,24 @@ class PDClient: raise ValueError(f"Found duplicate request_id: {request_id}!") # Queue to gather output from output_handler. - q: asyncio.Queue[PDResponse] = asyncio.Queue() + q = asyncio.Queue() self.queues[request_id] = q # (1) Perform the Prefill. original_max_tokens = sampling_params.max_tokens prompt_token_ids = prompt["prompt_token_ids"] - request = PDRequest(request_id=request_id, - prompt_token_ids=prompt_token_ids, - sampling_params=sampling_params) + request = PDGenerationRequest(request_id=request_id, + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params) request.sampling_params.max_tokens = 1 logger.debug("Sending Prefill: %s", request.request_id) - pd_response = await self._prefill(request, q) + pd_response = await self._run_prefill(request, q) # (2) Perform the Decodes. logger.debug("Sending Decode: %s", request.request_id) request.sampling_params.max_tokens = original_max_tokens - async for pd_response in self._decode(request, q): - logger.debug("Got Decode: %s", request.request_id) + async for pd_response in self._run_decode(request, q): + logger.debug("Got Response: %s", request.request_id) yield self._to_request_output(pd_response, prompt_token_ids) async def beam_search( diff --git a/vllm/disaggregated/pd_worker.py b/vllm/disaggregated/pd_worker.py index e5e4aea30e6f4..e24240d83f3df 100644 --- a/vllm/disaggregated/pd_worker.py +++ b/vllm/disaggregated/pd_worker.py @@ -65,9 +65,9 @@ class PDWorker: async def run_busy_loop(self): """ - Main execution loop for the PDWorker: - * 1) Wait for a request from the PDClient. - * 2) Handle the request. + main loop: + 1) wait for a request from the PDClient + 2) handle the request """ logger.info("PDWorker is ready To handle requests.") @@ -76,10 +76,14 @@ class PDWorker: req_type, req_data = await self.from_client.recv_multipart() # 2) Handle the request. - await self._handle_request(req_type, req_data.buffer) + await self._handle_request(req_type, req_data) async def _handle_request(self, req_type: bytes, req_data: bytes): - """Parse the request type and call the appropriate handler.""" + """ + request handler: + 1) parse the request type + 2) call the appropriate handler for the request type + """ if req_type == PDRequestType.GENERATION: req = self.decode_generation(req_data) await self._generation_handler(req) @@ -88,13 +92,18 @@ class PDWorker: await self._abort_handler(req) async def _generation_handler(self, req: PDGenerationRequest): - """Launch generation in a background task.""" + """ + Handle a PDGenerationRequest by launching a task. + """ task = asyncio.create_task(self._generate(req)) self.running_requests.add(task) task.add_done_callback(self.running_requests.discard) async def _abort_handler(self, req: PDGenerationRequest): - """Abort the request in the engine.""" + """ + Handle a PDAbortRequest by cancelling the running task. + The _generate coro aborts in the Engine. + """ # Convert running_requests set() into a dict(), keyed # by request_id. Cancel the task when an abort comes in. # Then update the _generate coroutine to handle a @@ -103,11 +112,11 @@ class PDWorker: async def _generate(self, req: PDGenerationRequest): """ - Handle a single PDRequest: + Handle a single PDGenerationRequest: * 1) submit request to AsyncLLM * 2) iterate the RequestOutputs * 3) convert RequestOutput --> PDResponse - * 4) serialize and send to Connector. + * 4) serialize and send to PDClient """ request_id = req.request_id diff --git a/vllm/disaggregated/protocol.py b/vllm/disaggregated/protocol.py index dfc02263207ba..ee7dde519acf3 100644 --- a/vllm/disaggregated/protocol.py +++ b/vllm/disaggregated/protocol.py @@ -55,9 +55,3 @@ class PDGenerationResponse(msgspec.Struct): finish_reason=out.finish_reason, stop_reason=out.stop_reason, ) - - -class PDGenerationFailure(msgspec.Struct): - request_id: str - error_message: str - engine_dead: bool diff --git a/vllm/entrypoints/disaggregated/connector.py b/vllm/entrypoints/disaggregated/connector.py index df8f1b4f86bdf..c098c433e08cf 100644 --- a/vllm/entrypoints/disaggregated/connector.py +++ b/vllm/entrypoints/disaggregated/connector.py @@ -1,5 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 -"""Toy Server For Prototyping.""" +""" +Toy connector for prototyping. + +When PDClient supports the protocol and we clean up the +OpenAI Server, we can drop this in favor of vllm serve. +""" from collections.abc import AsyncIterator from contextlib import asynccontextmanager @@ -9,7 +14,7 @@ import uvloop from fastapi import FastAPI, Request from fastapi.responses import JSONResponse, StreamingResponse -from vllm.entrypoints.disaggregated.engine import PDEngine +from vllm.disaggregated.pd_client import PDClient from vllm.entrypoints.openai.protocol import (CompletionRequest, CompletionResponse, ErrorResponse) @@ -46,34 +51,31 @@ async def create_completion(request: CompletionRequest, raw_request: Request): @asynccontextmanager -async def pd_engine_client_ctx_manager( - prefill_addr: str, decode_addr: str, connector_addr: str, - model_name: str) -> AsyncIterator[PDEngine]: - engine = PDEngine(prefill_addr, decode_addr, connector_addr, model_name) - yield engine - engine.shutdown() +async def pd_client_ctx(prefill_addr: str, decode_addr: str, + connector_addr: str, + model_name: str) -> AsyncIterator[PDClient]: + client = PDClient(prefill_addr, decode_addr, connector_addr, model_name) + yield client + client.shutdown() async def main(args, **uvicorn_kwargs): - logger.info("vLLM Disaggregate Connector Start %s %s", args, + logger.info("vLLM Disaggregated Connector Start %s %s", args, uvicorn_kwargs) # Avoid dropping requests under high concurrency. set_ulimit() # IPC Paths. - # NOTE FOR DEVELOPERS: when shifting to TCP, ensure you - # are not using pickle to avoid RCE security flaw. prefill_addr = f"ipc://{args.prefill_addr}" decode_addr = f"ipc://{args.decode_addr}" connector_addr = f"ipc://{args.connector_addr}" # Start Engine. - async with pd_engine_client_ctx_manager( - prefill_addr=prefill_addr, - decode_addr=decode_addr, - connector_addr=connector_addr, - model_name=args.model) as engine_client: + async with pd_client_ctx(prefill_addr=prefill_addr, + decode_addr=decode_addr, + connector_addr=connector_addr, + model_name=args.model) as engine_client: # Initialize App State. model_config = await engine_client.get_model_config() @@ -93,7 +95,7 @@ async def main(args, **uvicorn_kwargs): ) # Run Server. - config = uvicorn.Config(app, host="0.0.0.0", port=args.port) + config = uvicorn.Config(app, host=args.host, port=args.port) server = uvicorn.Server(config) await server.serve()