Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
This commit is contained in:
rshaw@neuralmagic.com 2025-03-24 00:46:46 +00:00
parent 7954461d4c
commit 70e06dd574
5 changed files with 40 additions and 35 deletions

View File

@ -37,7 +37,7 @@ wait_for_server() {
wait_for_disagg_server() { wait_for_disagg_server() {
local log_file=$1 local log_file=$1
timeout 1200 bash -c " timeout 1200 bash -c "
until grep -q 'zmq Server started at' $log_file; do until grep -q 'PD Worker is ready' $log_file; do
sleep 1 sleep 1
done" && return 0 || return 1 done" && return 0 || return 1
} }

View File

@ -53,9 +53,9 @@ class PDController(EngineClient):
| |
[ zmq ] [ zmq ]
| |
[ PDWorker ] [ PDWorker ] [ PDWorker ] [ PDWorker ]
| | |
[ Engine ] <---> [ Engine ] [ Engine ] <-kv-> [ Engine ]
After PR #12957, we will support xPyD, so we will After PR #12957, we will support xPyD, so we will
also need to implement a scheduler and service also need to implement a scheduler and service
@ -80,13 +80,12 @@ class PDController(EngineClient):
# TODO: once https://github.com/vllm-project/vllm/pull/12957 # TODO: once https://github.com/vllm-project/vllm/pull/12957
# lands, do service discovery to scale out workers. # lands, do service discovery to scale out workers.
self.ctx = zmq.asyncio.Context() self.ctx = zmq.asyncio.Context()
self.to_decode = self.ctx.socket(zmq.constants.PUSH)
self.to_decode.bind(f"{decode_addr}")
self.to_prefill = self.ctx.socket(zmq.constants.PUSH) self.to_prefill = self.ctx.socket(zmq.constants.PUSH)
self.to_prefill.bind(f"{prefill_addr}") self.to_prefill.connect(prefill_addr)
self.to_decode = self.ctx.socket(zmq.constants.PUSH)
self.to_decode.connect(decode_addr)
self.controller_addr = controller_addr self.controller_addr = controller_addr
self.decode_addr = decode_addr self.ipc_paths = [prefill_addr, decode_addr, controller_addr]
self.prefill_addr = prefill_addr
# Background loops (started on first generate()). # Background loops (started on first generate()).
self.output_handler: Optional[asyncio.Task] = None self.output_handler: Optional[asyncio.Task] = None
@ -123,8 +122,7 @@ class PDController(EngineClient):
if (task := self.output_handler) is not None: if (task := self.output_handler) is not None:
task.cancel() task.cancel()
ipc_paths = [self.controller_addr, self.decode_addr, self.prefill_addr] for path in self.ipc_paths:
for path in ipc_paths:
socket_path = path.replace("ipc://", "") socket_path = path.replace("ipc://", "")
if os.path.exists(socket_path): if os.path.exists(socket_path):
os.remove(socket_path) os.remove(socket_path)
@ -178,7 +176,7 @@ class PDController(EngineClient):
response = await q.get() response = await q.get()
if isinstance(response, Exception): if isinstance(response, Exception):
raise response raise response
logger.debug("Got Decode Response: %s", request.request_id) logger.debug("Prefill Response: %s", request.request_id)
async def _run_decode( async def _run_decode(
self, self,
@ -196,7 +194,7 @@ class PDController(EngineClient):
response = await q.get() response = await q.get()
if isinstance(response, Exception): if isinstance(response, Exception):
raise response raise response
logger.debug("Got Decode Response: %s", request.request_id) logger.debug("Decode Response: %s", request.request_id)
finished = response.finish_reason is not None finished = response.finish_reason is not None
yield response yield response
@ -269,11 +267,9 @@ class PDController(EngineClient):
prompt_token_ids=prompt["prompt_token_ids"], prompt_token_ids=prompt["prompt_token_ids"],
sampling_params=sampling_params) sampling_params=sampling_params)
request.sampling_params.max_tokens = 1 request.sampling_params.max_tokens = 1
logger.debug("Sending Prefill: %s", request.request_id)
pd_response = await self._run_prefill(request, q) pd_response = await self._run_prefill(request, q)
# (2) Perform the Decodes. # (2) Perform the Decodes.
logger.debug("Sending Decode: %s", request.request_id)
request.sampling_params.max_tokens = original_max_tokens request.sampling_params.max_tokens = original_max_tokens
async for pd_response in self._run_decode(request, q): async for pd_response in self._run_decode(request, q):
yield self._to_request_output(pd_response, yield self._to_request_output(pd_response,

View File

@ -8,7 +8,8 @@ import zmq
import zmq.asyncio import zmq.asyncio
from vllm.disaggregated.protocol import (PDAbortRequest, PDGenerationRequest, from vllm.disaggregated.protocol import (PDAbortRequest, PDGenerationRequest,
PDGenerationResponse, PDRequestType) PDGenerationResponse, PDRequestType,
PDResponseType)
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger from vllm.logger import init_logger
@ -34,13 +35,13 @@ class PDWorker:
self.engine = engine self.engine = engine
# ZMQ IPC. # ZMQ IPC.
self.worker_addr = worker_addr self.worker_addr = f"ipc://{worker_addr}"
self.controller_addr = controller_addr self.controller_addr = f"ipc://{controller_addr}"
self.ctx = zmq.asyncio.Context() self.ctx = zmq.asyncio.Context()
self.from_client = self.ctx.socket(zmq.constants.PULL) self.from_controller = self.ctx.socket(zmq.constants.PULL)
self.from_client.connect(f"ipc://{self.worker_addr}") self.from_controller.bind(self.worker_addr)
self.to_client = self.ctx.socket(zmq.constants.PUSH) self.to_controller = self.ctx.socket(zmq.constants.PUSH)
self.to_client.connect(f"ipc://{self.controller_addr}") self.to_controller.connect(self.controller_addr)
self.decode_generation = msgspec.msgpack.Decoder(PDGenerationRequest) self.decode_generation = msgspec.msgpack.Decoder(PDGenerationRequest)
self.decode_abort = msgspec.msgpack.Decoder(PDAbortRequest) self.decode_abort = msgspec.msgpack.Decoder(PDAbortRequest)
self.encoder = msgspec.msgpack.Encoder() self.encoder = msgspec.msgpack.Encoder()
@ -71,9 +72,12 @@ class PDWorker:
""" """
logger.info("PDWorker is ready To handle requests.") logger.info("PDWorker is ready To handle requests.")
poller = zmq.asyncio.Poller()
poller.register(self.from_controller, zmq.POLLIN)
while True: while True:
# 1) Get request from the Connector. # 1) Get request from the Connector.
req_type, req_data = await self.from_client.recv_multipart() req_type, req_data = await self.from_controller.recv_multipart()
# 2) Handle the request. # 2) Handle the request.
await self._handle_request(req_type, req_data) await self._handle_request(req_type, req_data)
@ -85,11 +89,13 @@ class PDWorker:
2) call the appropriate handler for the request type 2) call the appropriate handler for the request type
""" """
if req_type == PDRequestType.GENERATION: if req_type == PDRequestType.GENERATION:
req = self.decode_generation(req_data) req = self.decode_generation.decode(req_data)
await self._generation_handler(req) await self._generation_handler(req)
elif req_type == PDRequestType.ABORT: elif req_type == PDRequestType.ABORT:
req = self.decode_abort(req_data) req = self.decode_abort.decode(req_data)
await self._abort_handler(req) await self._abort_handler(req)
else:
raise Exception(f"Unknown Request Type: {req_type}.")
async def _generation_handler(self, req: PDGenerationRequest): async def _generation_handler(self, req: PDGenerationRequest):
""" """
@ -133,7 +139,5 @@ class PDWorker:
# 4) Serialize and send to PDController. # 4) Serialize and send to PDController.
response_bytes = self.encoder.encode(response) response_bytes = self.encoder.encode(response)
msg = [PDGenerationResponse.SUCCE, response_bytes] msg = (PDResponseType.GENERATION, response_bytes)
logger.debug("Sending: %s", request_id) await self.to_controller.send_multipart(msg, copy=False)
await self.to_client.send_multipart(msg, copy=False)
logger.debug("Sent: %s", request_id)

View File

@ -1,9 +1,10 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
""" """
Toy connector for prototyping. Toy API Server for prototyping.
When PDController supports the protocol and we clean up Once the PDController is more mature and we clean up
OpenAI Server, we can drop this in favor of vllm serve. the OpenAI Server a bit, we can put the PDController
directly inside and launch with vllm serve.
""" """
from collections.abc import AsyncIterator from collections.abc import AsyncIterator

View File

@ -10,12 +10,14 @@ from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__) logger = init_logger("vllm.entrypoints.disaggregated.worker")
async def run(args, engine: EngineClient): async def run(args, engine: EngineClient):
try: try:
worker = PDWorker(engine, args.worker_addr, args.controller_addr) worker = PDWorker(engine=engine,
worker_addr=args.worker_addr,
controller_addr=args.controller_addr)
await worker.run_busy_loop() await worker.run_busy_loop()
finally: finally:
worker.shutdown() worker.shutdown()
@ -25,7 +27,6 @@ async def main(args) -> None:
logger.info("vLLM P/D Worker Server %s", VLLM_VERSION) logger.info("vLLM P/D Worker Server %s", VLLM_VERSION)
logger.info("Args: %s", args) logger.info("Args: %s", args)
args.disable_frontend_multiprocessing = False
async with build_async_engine_client(args) as engine: async with build_async_engine_client(args) as engine:
await run(args, engine) await run(args, engine)
@ -40,5 +41,8 @@ if __name__ == "__main__":
type=str, type=str,
required=True, required=True,
help='The address of the worker.') help='The address of the worker.')
parser.add_argument('--disable-frontend-multiprocessing',
action="store_true",
help='Disable MQLLMEngine for AsyncLLMEngine.')
AsyncEngineArgs.add_cli_args(parser) AsyncEngineArgs.add_cli_args(parser)
uvloop.run(main(parser.parse_args())) uvloop.run(main(parser.parse_args()))