Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
This commit is contained in:
rshaw@neuralmagic.com 2025-03-24 00:46:46 +00:00
parent 7954461d4c
commit 70e06dd574
5 changed files with 40 additions and 35 deletions

View File

@ -37,7 +37,7 @@ wait_for_server() {
wait_for_disagg_server() {
local log_file=$1
timeout 1200 bash -c "
until grep -q 'zmq Server started at' $log_file; do
until grep -q 'PD Worker is ready' $log_file; do
sleep 1
done" && return 0 || return 1
}

View File

@ -53,9 +53,9 @@ class PDController(EngineClient):
|
[ zmq ]
|
[ PDWorker ] [ PDWorker ]
|
[ Engine ] <---> [ Engine ]
[ PDWorker ] [ PDWorker ]
| |
[ Engine ] <-kv-> [ Engine ]
After PR #12957, we will support xPyD, so we will
also need to implement a scheduler and service
@ -80,13 +80,12 @@ class PDController(EngineClient):
# TODO: once https://github.com/vllm-project/vllm/pull/12957
# lands, do service discovery to scale out workers.
self.ctx = zmq.asyncio.Context()
self.to_decode = self.ctx.socket(zmq.constants.PUSH)
self.to_decode.bind(f"{decode_addr}")
self.to_prefill = self.ctx.socket(zmq.constants.PUSH)
self.to_prefill.bind(f"{prefill_addr}")
self.to_prefill.connect(prefill_addr)
self.to_decode = self.ctx.socket(zmq.constants.PUSH)
self.to_decode.connect(decode_addr)
self.controller_addr = controller_addr
self.decode_addr = decode_addr
self.prefill_addr = prefill_addr
self.ipc_paths = [prefill_addr, decode_addr, controller_addr]
# Background loops (started on first generate()).
self.output_handler: Optional[asyncio.Task] = None
@ -123,8 +122,7 @@ class PDController(EngineClient):
if (task := self.output_handler) is not None:
task.cancel()
ipc_paths = [self.controller_addr, self.decode_addr, self.prefill_addr]
for path in ipc_paths:
for path in self.ipc_paths:
socket_path = path.replace("ipc://", "")
if os.path.exists(socket_path):
os.remove(socket_path)
@ -178,7 +176,7 @@ class PDController(EngineClient):
response = await q.get()
if isinstance(response, Exception):
raise response
logger.debug("Got Decode Response: %s", request.request_id)
logger.debug("Prefill Response: %s", request.request_id)
async def _run_decode(
self,
@ -196,7 +194,7 @@ class PDController(EngineClient):
response = await q.get()
if isinstance(response, Exception):
raise response
logger.debug("Got Decode Response: %s", request.request_id)
logger.debug("Decode Response: %s", request.request_id)
finished = response.finish_reason is not None
yield response
@ -269,11 +267,9 @@ class PDController(EngineClient):
prompt_token_ids=prompt["prompt_token_ids"],
sampling_params=sampling_params)
request.sampling_params.max_tokens = 1
logger.debug("Sending Prefill: %s", request.request_id)
pd_response = await self._run_prefill(request, q)
# (2) Perform the Decodes.
logger.debug("Sending Decode: %s", request.request_id)
request.sampling_params.max_tokens = original_max_tokens
async for pd_response in self._run_decode(request, q):
yield self._to_request_output(pd_response,

View File

@ -8,7 +8,8 @@ import zmq
import zmq.asyncio
from vllm.disaggregated.protocol import (PDAbortRequest, PDGenerationRequest,
PDGenerationResponse, PDRequestType)
PDGenerationResponse, PDRequestType,
PDResponseType)
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger
@ -34,13 +35,13 @@ class PDWorker:
self.engine = engine
# ZMQ IPC.
self.worker_addr = worker_addr
self.controller_addr = controller_addr
self.worker_addr = f"ipc://{worker_addr}"
self.controller_addr = f"ipc://{controller_addr}"
self.ctx = zmq.asyncio.Context()
self.from_client = self.ctx.socket(zmq.constants.PULL)
self.from_client.connect(f"ipc://{self.worker_addr}")
self.to_client = self.ctx.socket(zmq.constants.PUSH)
self.to_client.connect(f"ipc://{self.controller_addr}")
self.from_controller = self.ctx.socket(zmq.constants.PULL)
self.from_controller.bind(self.worker_addr)
self.to_controller = self.ctx.socket(zmq.constants.PUSH)
self.to_controller.connect(self.controller_addr)
self.decode_generation = msgspec.msgpack.Decoder(PDGenerationRequest)
self.decode_abort = msgspec.msgpack.Decoder(PDAbortRequest)
self.encoder = msgspec.msgpack.Encoder()
@ -71,9 +72,12 @@ class PDWorker:
"""
logger.info("PDWorker is ready To handle requests.")
poller = zmq.asyncio.Poller()
poller.register(self.from_controller, zmq.POLLIN)
while True:
# 1) Get request from the Controller.
req_type, req_data = await self.from_client.recv_multipart()
req_type, req_data = await self.from_controller.recv_multipart()
# 2) Handle the request.
await self._handle_request(req_type, req_data)
@ -85,11 +89,13 @@ class PDWorker:
2) call the appropriate handler for the request type
"""
if req_type == PDRequestType.GENERATION:
req = self.decode_generation(req_data)
req = self.decode_generation.decode(req_data)
await self._generation_handler(req)
elif req_type == PDRequestType.ABORT:
req = self.decode_abort(req_data)
req = self.decode_abort.decode(req_data)
await self._abort_handler(req)
else:
raise Exception(f"Unknown Request Type: {req_type}.")
async def _generation_handler(self, req: PDGenerationRequest):
"""
@ -133,7 +139,5 @@ class PDWorker:
# 4) Serialize and send to PDController.
response_bytes = self.encoder.encode(response)
msg = [PDGenerationResponse.SUCCE, response_bytes]
logger.debug("Sending: %s", request_id)
await self.to_client.send_multipart(msg, copy=False)
logger.debug("Sent: %s", request_id)
msg = (PDResponseType.GENERATION, response_bytes)
await self.to_controller.send_multipart(msg, copy=False)

View File

@ -1,9 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
"""
Toy connector for prototyping.
Toy API Server for prototyping.
When PDController supports the protocol and we clean up the
OpenAI Server, we can drop this in favor of vllm serve.
Once the PDController is more mature and we clean up
the OpenAI Server a bit, we can put the PDController
directly inside and launch with vllm serve.
"""
from collections.abc import AsyncIterator

View File

@ -10,12 +10,14 @@ from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
logger = init_logger("vllm.entrypoints.disaggregated.worker")
async def run(args, engine: EngineClient):
try:
worker = PDWorker(engine, args.worker_addr, args.controller_addr)
worker = PDWorker(engine=engine,
worker_addr=args.worker_addr,
controller_addr=args.controller_addr)
await worker.run_busy_loop()
finally:
worker.shutdown()
@ -25,7 +27,6 @@ async def main(args) -> None:
logger.info("vLLM P/D Worker Server %s", VLLM_VERSION)
logger.info("Args: %s", args)
args.disable_frontend_multiprocessing = False
async with build_async_engine_client(args) as engine:
await run(args, engine)
@ -40,5 +41,8 @@ if __name__ == "__main__":
type=str,
required=True,
help='The address of the worker.')
parser.add_argument('--disable-frontend-multiprocessing',
action="store_true",
help='Disable MQLLMEngine for AsyncLLMEngine.')
AsyncEngineArgs.add_cli_args(parser)
uvloop.run(main(parser.parse_args()))