Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
This commit is contained in:
rshaw@neuralmagic.com 2025-03-23 22:36:57 +00:00
parent 28d0396ff1
commit 66349c33a1
4 changed files with 119 additions and 81 deletions

View File

@ -3,7 +3,7 @@
import asyncio import asyncio
import os import os
from collections.abc import AsyncGenerator, Mapping from collections.abc import AsyncGenerator, Mapping
from typing import Optional from typing import Optional, Union
import msgspec import msgspec
import zmq import zmq
@ -11,7 +11,10 @@ import zmq.asyncio
from vllm.config import DecodingConfig, ModelConfig from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.entrypoints.disaggregated.types import PDRequest, PDResponse from vllm.disaggregated.protocol import (PDGenerationRequest,
PDGenerationResponse, PDRequestType,
PDResponseType)
from vllm.engine.protocol import EngineClient
from vllm.inputs.data import PromptType from vllm.inputs.data import PromptType
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
@ -30,14 +33,36 @@ logger = init_logger(__name__)
DEFAULT_MAX_TOKENS = 32000 DEFAULT_MAX_TOKENS = 32000
class PDClient: class PDController(EngineClient):
""" """
PDEngine: Controller that schedules work on the PDWorkers.
Equiavlent of AsyncLLM for P/D. Assumes there is
a Prefill and Decode service already running.
* TODO: actually handle errors and failure. Conforms for the EngineClient protocol so it can
* TODO: support more than just text input. be wrapped with the OpenAI Server.
Two Phases:
* Send request to prefill worker, await ack.
* Send request to decode worker.
KVSync happens directly between Engines,
handled by vLLM KVCacheTransfer.
[ OpenAI Server ]
|
[ PDController ]
|
[ zmq ]
|
[ PDWorker ] [ PDWorker ]
|
[ Engine ] <---> [ Engine ]
After PR #12957, we will support xPyD, so we will
also need to implement a scheduler.
we will need to support multiple
* TODO: actually handle errors and failure.
* TODO: support the full API (logprobs, multimodal).
""" """
def __init__(self, prefill_addr: str, decode_addr: str, def __init__(self, prefill_addr: str, decode_addr: str,
@ -49,6 +74,8 @@ class PDClient:
self.encoder = msgspec.msgpack.Encoder() self.encoder = msgspec.msgpack.Encoder()
# ZMQ communication. # ZMQ communication.
# TODO: once https://github.com/vllm-project/vllm/pull/12957
# lands, do service discovery to scale out workers.
self.ctx = zmq.asyncio.Context() self.ctx = zmq.asyncio.Context()
self.to_decode = self.ctx.socket(zmq.constants.PUSH) self.to_decode = self.ctx.socket(zmq.constants.PUSH)
self.to_decode.bind(f"{decode_addr}") self.to_decode.bind(f"{decode_addr}")
@ -63,7 +90,7 @@ class PDClient:
self.log_running: Optional[asyncio.Task] = None self.log_running: Optional[asyncio.Task] = None
# Dummy: needed for EngineClient Protocol. # Dummy: needed for EngineClient Protocol.
# TODO: refactor EngineClient to avoid needing this. # TODO: refactor OAI Server to avoid needing this.
self.model_config = ModelConfig(model=model_name, self.model_config = ModelConfig(model=model_name,
tokenizer=model_name, tokenizer=model_name,
tokenizer_mode="auto", tokenizer_mode="auto",
@ -73,7 +100,7 @@ class PDClient:
seed=42) seed=42)
# Dummy: needed for EngineClient Protocol. # Dummy: needed for EngineClient Protocol.
# TODO: refactor EngineClient to avoid needing this. # TODO: refactor OAI Server to avoid needing this.
init_kwargs = dict( init_kwargs = dict(
tokenizer_id=self.model_config.tokenizer, tokenizer_id=self.model_config.tokenizer,
enable_lora=False, enable_lora=False,
@ -109,7 +136,7 @@ class PDClient:
Pull responses from Decode + Prefill engines and Pull responses from Decode + Prefill engines and
distribute back to the generate() tasks. distribute back to the generate() tasks.
""" """
decoder = msgspec.msgpack.Decoder(PDResponse) decoder = msgspec.msgpack.Decoder(PDGenerationResponse)
socket: Optional[zmq.asyncio.Socket] = None socket: Optional[zmq.asyncio.Socket] = None
try: try:
@ -117,72 +144,78 @@ class PDClient:
socket.bind(self.connector_addr) socket.bind(self.connector_addr)
while True: while True:
reponse_bytes = await socket.recv() res_type, res_data = await socket.recv_multipart()
response = decoder.decode(reponse_bytes) if res_type == PDResponseType.FAILURE:
logger.debug("Got Response: %s", response.request_id) raise Exception("Failure Response from PDWorker.")
self.queues[response.request_id].put_nowait(response) elif res_type == PDResponseType.GENERATION:
except: response = decoder.decode(res_data)
# TODO: actually handle failure and shutdown. logger.debug("Got Response: %s", response.request_id)
raise self.queues[response.request_id].put_nowait(response)
else:
raise Exception("Unknown response type.")
except Exception as e:
# TODO: distinguish between fatal and non-fatal errors.
for _, q in self.queues.values():
q.put_nowait(e)
raise e
finally: finally:
if socket is not None: if socket is not None:
socket.close(linger=0) socket.close(linger=0)
async def _prefill( async def _run_prefill(
self, self,
request: PDRequest, request: PDGenerationRequest,
q: asyncio.Queue[PDResponse], q: asyncio.Queue[Union[Exception, PDGenerationResponse]],
): ):
# Send request to the prefill instance. # Send request to the prefill instance.
req_bytes = self.encoder.encode(request) req_bytes = self.encoder.encode(request)
await self.to_prefill.send(req_bytes, copy=False) msg = (PDRequestType.GENERATION, req_bytes)
await self.to_prefill.send_multipart(msg, copy=False)
# Wait for the prefill to be done. # Wait for the prefill to be done.
response = await q.get() response = await q.get()
assert response.request_id == request.request_id if isinstance(response, Exception):
if not response.success: raise response
# TODO: actual error handling and shutdown.
raise Exception("Failed Prefill Request.")
async def _decode( async def _run_decode(
self, self,
request: PDRequest, request: PDGenerationRequest,
q: asyncio.Queue[PDResponse], q: asyncio.Queue[Union[Exception, PDGenerationResponse]],
) -> AsyncGenerator[PDRequest]: ) -> AsyncGenerator[PDGenerationResponse]:
# Send request to the decode instance. # Send request to the decode instance.
req_bytes = self.encoder.encode(request) req_bytes = self.encoder.encode(request)
await self.to_decode.send(req_bytes, copy=False) msg = (PDRequestType.GENERATION, req_bytes)
await self.to_decode.send_multipart(msg, copy=False)
# Iterate response queue and yield each response to caller.. # Iterate response queue and yield each response to caller.
finished = False finished = False
while not finished: while not finished:
response = await q.get() response = await q.get()
if not response.success: if isinstance(response, Exception):
# TODO: actual error handling and shutdown. raise response
raise Exception("Failed Decode Request.")
finished = response.finish_reason is not None finished = response.finish_reason is not None
yield response yield response
def _to_request_output( def _to_request_output(
self, self,
pd_response: PDResponse, response: PDGenerationResponse,
prompt_token_ids: list[int], prompt_token_ids: list[int],
) -> RequestOutput: ) -> RequestOutput:
finished = pd_response.finish_reason is not None finished = response.finish_reason is not None
return RequestOutput( return RequestOutput(
request_id=pd_response.request_id, request_id=response.request_id,
prompt=None, prompt=None,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
prompt_logprobs=None, prompt_logprobs=None,
outputs=[ outputs=[
CompletionOutput(index=0, CompletionOutput(index=0,
text=pd_response.text, text=response.text,
token_ids=pd_response.token_ids, token_ids=response.token_ids,
cumulative_logprob=None, cumulative_logprob=None,
logprobs=None, logprobs=None,
finish_reason=pd_response.finish_reason, finish_reason=response.finish_reason,
stop_reason=pd_response.stop_reason) stop_reason=response.stop_reason)
], ],
finished=finished, finished=finished,
) )
@ -223,24 +256,24 @@ class PDClient:
raise ValueError(f"Found duplicate request_id: {request_id}!") raise ValueError(f"Found duplicate request_id: {request_id}!")
# Queue to gather output from output_handler. # Queue to gather output from output_handler.
q: asyncio.Queue[PDResponse] = asyncio.Queue() q = asyncio.Queue()
self.queues[request_id] = q self.queues[request_id] = q
# (1) Perform the Prefill. # (1) Perform the Prefill.
original_max_tokens = sampling_params.max_tokens original_max_tokens = sampling_params.max_tokens
prompt_token_ids = prompt["prompt_token_ids"] prompt_token_ids = prompt["prompt_token_ids"]
request = PDRequest(request_id=request_id, request = PDGenerationRequest(request_id=request_id,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params) sampling_params=sampling_params)
request.sampling_params.max_tokens = 1 request.sampling_params.max_tokens = 1
logger.debug("Sending Prefill: %s", request.request_id) logger.debug("Sending Prefill: %s", request.request_id)
pd_response = await self._prefill(request, q) pd_response = await self._run_prefill(request, q)
# (2) Perform the Decodes. # (2) Perform the Decodes.
logger.debug("Sending Decode: %s", request.request_id) logger.debug("Sending Decode: %s", request.request_id)
request.sampling_params.max_tokens = original_max_tokens request.sampling_params.max_tokens = original_max_tokens
async for pd_response in self._decode(request, q): async for pd_response in self._run_decode(request, q):
logger.debug("Got Decode: %s", request.request_id) logger.debug("Got Response: %s", request.request_id)
yield self._to_request_output(pd_response, prompt_token_ids) yield self._to_request_output(pd_response, prompt_token_ids)
async def beam_search( async def beam_search(

View File

@ -65,9 +65,9 @@ class PDWorker:
async def run_busy_loop(self): async def run_busy_loop(self):
""" """
Main execution loop for the PDWorker: main loop:
* 1) Wait for a request from the PDClient. 1) wait for a request from the PDClient
* 2) Handle the request. 2) handle the request
""" """
logger.info("PDWorker is ready To handle requests.") logger.info("PDWorker is ready To handle requests.")
@ -76,10 +76,14 @@ class PDWorker:
req_type, req_data = await self.from_client.recv_multipart() req_type, req_data = await self.from_client.recv_multipart()
# 2) Handle the request. # 2) Handle the request.
await self._handle_request(req_type, req_data.buffer) await self._handle_request(req_type, req_data)
async def _handle_request(self, req_type: bytes, req_data: bytes): async def _handle_request(self, req_type: bytes, req_data: bytes):
"""Parse the request type and call the appropriate handler.""" """
request handler:
1) parse the request type
2) call the appropriate handler for the request type
"""
if req_type == PDRequestType.GENERATION: if req_type == PDRequestType.GENERATION:
req = self.decode_generation(req_data) req = self.decode_generation(req_data)
await self._generation_handler(req) await self._generation_handler(req)
@ -88,13 +92,18 @@ class PDWorker:
await self._abort_handler(req) await self._abort_handler(req)
async def _generation_handler(self, req: PDGenerationRequest): async def _generation_handler(self, req: PDGenerationRequest):
"""Launch generation in a background task.""" """
Handle a PDGenerationRequest by launching a task.
"""
task = asyncio.create_task(self._generate(req)) task = asyncio.create_task(self._generate(req))
self.running_requests.add(task) self.running_requests.add(task)
task.add_done_callback(self.running_requests.discard) task.add_done_callback(self.running_requests.discard)
async def _abort_handler(self, req: PDGenerationRequest): async def _abort_handler(self, req: PDGenerationRequest):
"""Abort the request in the engine.""" """
Handle a PDAbortRequest by cancelling the running task.
The _generate coro aborts in the Engine.
"""
# Convert running_requests set() into a dict(), keyed # Convert running_requests set() into a dict(), keyed
# by request_id. Cancel the task when an abort comes in. # by request_id. Cancel the task when an abort comes in.
# Then update the _generate coroutine to handle a # Then update the _generate coroutine to handle a
@ -103,11 +112,11 @@ class PDWorker:
async def _generate(self, req: PDGenerationRequest): async def _generate(self, req: PDGenerationRequest):
""" """
Handle a single PDRequest: Handle a single PDGenerationRequest:
* 1) submit request to AsyncLLM * 1) submit request to AsyncLLM
* 2) iterate the RequestOutputs * 2) iterate the RequestOutputs
* 3) convert RequestOutput --> PDResponse * 3) convert RequestOutput --> PDResponse
* 4) serialize and send to Connector. * 4) serialize and send to PDClient
""" """
request_id = req.request_id request_id = req.request_id

View File

@ -55,9 +55,3 @@ class PDGenerationResponse(msgspec.Struct):
finish_reason=out.finish_reason, finish_reason=out.finish_reason,
stop_reason=out.stop_reason, stop_reason=out.stop_reason,
) )
class PDGenerationFailure(msgspec.Struct):
request_id: str
error_message: str
engine_dead: bool

View File

@ -1,5 +1,10 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Toy Server For Prototyping.""" """
Toy connector for prototyping.
When PDClient supports the protocol and we clean up the
OpenAI Server, we can drop this in favor of vllm serve.
"""
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
@ -9,7 +14,7 @@ import uvloop
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from vllm.entrypoints.disaggregated.engine import PDEngine from vllm.disaggregated.pd_client import PDClient
from vllm.entrypoints.openai.protocol import (CompletionRequest, from vllm.entrypoints.openai.protocol import (CompletionRequest,
CompletionResponse, CompletionResponse,
ErrorResponse) ErrorResponse)
@ -46,34 +51,31 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
@asynccontextmanager @asynccontextmanager
async def pd_engine_client_ctx_manager( async def pd_client_ctx(prefill_addr: str, decode_addr: str,
prefill_addr: str, decode_addr: str, connector_addr: str, connector_addr: str,
model_name: str) -> AsyncIterator[PDEngine]: model_name: str) -> AsyncIterator[PDClient]:
engine = PDEngine(prefill_addr, decode_addr, connector_addr, model_name) client = PDClient(prefill_addr, decode_addr, connector_addr, model_name)
yield engine yield client
engine.shutdown() client.shutdown()
async def main(args, **uvicorn_kwargs): async def main(args, **uvicorn_kwargs):
logger.info("vLLM Disaggregate Connector Start %s %s", args, logger.info("vLLM Disaggregated Connector Start %s %s", args,
uvicorn_kwargs) uvicorn_kwargs)
# Avoid dropping requests under high concurrency. # Avoid dropping requests under high concurrency.
set_ulimit() set_ulimit()
# IPC Paths. # IPC Paths.
# NOTE FOR DEVELOPERS: when shifting to TCP, ensure you
# are not using pickle to avoid RCE security flaw.
prefill_addr = f"ipc://{args.prefill_addr}" prefill_addr = f"ipc://{args.prefill_addr}"
decode_addr = f"ipc://{args.decode_addr}" decode_addr = f"ipc://{args.decode_addr}"
connector_addr = f"ipc://{args.connector_addr}" connector_addr = f"ipc://{args.connector_addr}"
# Start Engine. # Start Engine.
async with pd_engine_client_ctx_manager( async with pd_client_ctx(prefill_addr=prefill_addr,
prefill_addr=prefill_addr, decode_addr=decode_addr,
decode_addr=decode_addr, connector_addr=connector_addr,
connector_addr=connector_addr, model_name=args.model) as engine_client:
model_name=args.model) as engine_client:
# Initialize App State. # Initialize App State.
model_config = await engine_client.get_model_config() model_config = await engine_client.get_model_config()
@ -93,7 +95,7 @@ async def main(args, **uvicorn_kwargs):
) )
# Run Server. # Run Server.
config = uvicorn.Config(app, host="0.0.0.0", port=args.port) config = uvicorn.Config(app, host=args.host, port=args.port)
server = uvicorn.Server(config) server = uvicorn.Server(config)
await server.serve() await server.serve()