diff --git a/examples/online_serving/disaggregated_prefill_zmq.sh b/examples/online_serving/disaggregated_prefill_zmq.sh index adea0c6929fb0..12c1108101b53 100644 --- a/examples/online_serving/disaggregated_prefill_zmq.sh +++ b/examples/online_serving/disaggregated_prefill_zmq.sh @@ -37,7 +37,7 @@ wait_for_server() { wait_for_disagg_server() { local log_file=$1 timeout 1200 bash -c " - until grep -q 'zmq Server started at' $log_file; do + until grep -q 'PD Worker is ready' $log_file; do sleep 1 done" && return 0 || return 1 } diff --git a/vllm/disaggregated/pd_controller.py b/vllm/disaggregated/pd_controller.py index ca6d8fe10781d..041468f9a6a4a 100644 --- a/vllm/disaggregated/pd_controller.py +++ b/vllm/disaggregated/pd_controller.py @@ -53,9 +53,9 @@ class PDController(EngineClient): | [ zmq ] | - [ PDWorker ] [ PDWorker ] - | - [ Engine ] <---> [ Engine ] + [ PDWorker ] [ PDWorker ] + | | + [ Engine ] <-kv-> [ Engine ] After PR #12957, we will support xPyD, so we will also need to implement a scheduler and service @@ -80,13 +80,12 @@ class PDController(EngineClient): # TODO: once https://github.com/vllm-project/vllm/pull/12957 # lands, do service discovery to scale out workers. self.ctx = zmq.asyncio.Context() - self.to_decode = self.ctx.socket(zmq.constants.PUSH) - self.to_decode.bind(f"{decode_addr}") self.to_prefill = self.ctx.socket(zmq.constants.PUSH) - self.to_prefill.bind(f"{prefill_addr}") + self.to_prefill.connect(prefill_addr) + self.to_decode = self.ctx.socket(zmq.constants.PUSH) + self.to_decode.connect(decode_addr) self.controller_addr = controller_addr - self.decode_addr = decode_addr - self.prefill_addr = prefill_addr + self.ipc_paths = [prefill_addr, decode_addr, controller_addr] # Background loops (started on first generate()). 
self.output_handler: Optional[asyncio.Task] = None @@ -123,8 +122,7 @@ class PDController(EngineClient): if (task := self.output_handler) is not None: task.cancel() - ipc_paths = [self.controller_addr, self.decode_addr, self.prefill_addr] - for path in ipc_paths: + for path in self.ipc_paths: socket_path = path.replace("ipc://", "") if os.path.exists(socket_path): os.remove(socket_path) @@ -178,7 +176,7 @@ class PDController(EngineClient): response = await q.get() if isinstance(response, Exception): raise response - logger.debug("Got Decode Response: %s", request.request_id) + logger.debug("Prefill Response: %s", request.request_id) async def _run_decode( self, @@ -196,7 +194,7 @@ class PDController(EngineClient): response = await q.get() if isinstance(response, Exception): raise response - logger.debug("Got Decode Response: %s", request.request_id) + logger.debug("Decode Response: %s", request.request_id) finished = response.finish_reason is not None yield response @@ -269,11 +267,9 @@ class PDController(EngineClient): prompt_token_ids=prompt["prompt_token_ids"], sampling_params=sampling_params) request.sampling_params.max_tokens = 1 - logger.debug("Sending Prefill: %s", request.request_id) pd_response = await self._run_prefill(request, q) # (2) Perform the Decodes. 
- logger.debug("Sending Decode: %s", request.request_id) request.sampling_params.max_tokens = original_max_tokens async for pd_response in self._run_decode(request, q): yield self._to_request_output(pd_response, diff --git a/vllm/disaggregated/pd_worker.py b/vllm/disaggregated/pd_worker.py index 99c3302b5b886..3a8c9ba1cc1b8 100644 --- a/vllm/disaggregated/pd_worker.py +++ b/vllm/disaggregated/pd_worker.py @@ -8,7 +8,8 @@ import zmq import zmq.asyncio from vllm.disaggregated.protocol import (PDAbortRequest, PDGenerationRequest, - PDGenerationResponse, PDRequestType) + PDGenerationResponse, PDRequestType, + PDResponseType) from vllm.engine.protocol import EngineClient from vllm.logger import init_logger @@ -34,13 +35,13 @@ class PDWorker: self.engine = engine # ZMQ IPC. - self.worker_addr = worker_addr - self.controller_addr = controller_addr + self.worker_addr = f"ipc://{worker_addr}" + self.controller_addr = f"ipc://{controller_addr}" self.ctx = zmq.asyncio.Context() - self.from_client = self.ctx.socket(zmq.constants.PULL) - self.from_client.connect(f"ipc://{self.worker_addr}") - self.to_client = self.ctx.socket(zmq.constants.PUSH) - self.to_client.connect(f"ipc://{self.controller_addr}") + self.from_controller = self.ctx.socket(zmq.constants.PULL) + self.from_controller.bind(self.worker_addr) + self.to_controller = self.ctx.socket(zmq.constants.PUSH) + self.to_controller.connect(self.controller_addr) self.decode_generation = msgspec.msgpack.Decoder(PDGenerationRequest) self.decode_abort = msgspec.msgpack.Decoder(PDAbortRequest) self.encoder = msgspec.msgpack.Encoder() @@ -71,9 +72,12 @@ class PDWorker: """ logger.info("PDWorker is ready To handle requests.") + poller = zmq.asyncio.Poller() + poller.register(self.from_controller, zmq.POLLIN) + while True: # 1) Get request from the Connector. - req_type, req_data = await self.from_client.recv_multipart() + req_type, req_data = await self.from_controller.recv_multipart() # 2) Handle the request. 
await self._handle_request(req_type, req_data) @@ -85,11 +89,13 @@ class PDWorker: 2) call the appropriate handler for the request type """ if req_type == PDRequestType.GENERATION: - req = self.decode_generation(req_data) + req = self.decode_generation.decode(req_data) await self._generation_handler(req) elif req_type == PDRequestType.ABORT: - req = self.decode_abort(req_data) + req = self.decode_abort.decode(req_data) await self._abort_handler(req) + else: + raise Exception(f"Unknown Request Type: {req_type}.") async def _generation_handler(self, req: PDGenerationRequest): """ @@ -133,7 +139,5 @@ class PDWorker: # 4) Serialize and send to PDConroller. response_bytes = self.encoder.encode(response) - msg = [PDGenerationResponse.SUCCE, response_bytes] - logger.debug("Sending: %s", request_id) - await self.to_client.send_multipart(msg, copy=False) - logger.debug("Sent: %s", request_id) + msg = (PDResponseType.GENERATION, response_bytes) + await self.to_controller.send_multipart(msg, copy=False) diff --git a/vllm/entrypoints/disaggregated/api_server.py b/vllm/entrypoints/disaggregated/api_server.py index f2f925d6716fb..e4537d34725cc 100644 --- a/vllm/entrypoints/disaggregated/api_server.py +++ b/vllm/entrypoints/disaggregated/api_server.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 """ -Toy connector for prototyping. +Toy API Server for prototyping. -When PDConroller supports the protocol and we clean up the -OpenAI Server, we can drop this in favor of vllm serve. +Once the PDController is more mature and we clean up +the OpenAI Server a bit, we can put the PDController +directly inside and launch with vllm serve. 
""" from collections.abc import AsyncIterator diff --git a/vllm/entrypoints/disaggregated/worker.py b/vllm/entrypoints/disaggregated/worker.py index 041101d62dfab..5ec2410d8840a 100644 --- a/vllm/entrypoints/disaggregated/worker.py +++ b/vllm/entrypoints/disaggregated/worker.py @@ -10,12 +10,14 @@ from vllm.logger import init_logger from vllm.utils import FlexibleArgumentParser from vllm.version import __version__ as VLLM_VERSION -logger = init_logger(__name__) +logger = init_logger("vllm.entrypoints.disaggregated.worker") async def run(args, engine: EngineClient): try: - worker = PDWorker(engine, args.worker_addr, args.controller_addr) + worker = PDWorker(engine=engine, + worker_addr=args.worker_addr, + controller_addr=args.controller_addr) await worker.run_busy_loop() finally: worker.shutdown() @@ -25,7 +27,6 @@ async def main(args) -> None: logger.info("vLLM P/D Worker Server %s", VLLM_VERSION) logger.info("Args: %s", args) - args.disable_frontend_multiprocessing = False async with build_async_engine_client(args) as engine: await run(args, engine) @@ -40,5 +41,8 @@ if __name__ == "__main__": type=str, required=True, help='The address of the worker.') + parser.add_argument('--disable-frontend-multiprocessing', + action="store_true", + help='Disable MQLLMEngine for AsyncLLMEngine.') AsyncEngineArgs.add_cli_args(parser) uvloop.run(main(parser.parse_args()))