Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
rshaw@neuralmagic.com 2025-03-23 22:36:57 +00:00
parent 28d0396ff1
commit 66349c33a1
4 changed files with 119 additions and 81 deletions

View File

@ -3,7 +3,7 @@
import asyncio
import os
from collections.abc import AsyncGenerator, Mapping
from typing import Optional
from typing import Optional, Union
import msgspec
import zmq
@ -11,7 +11,10 @@ import zmq.asyncio
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.entrypoints.disaggregated.types import PDRequest, PDResponse
from vllm.disaggregated.protocol import (PDGenerationRequest,
PDGenerationResponse, PDRequestType,
PDResponseType)
from vllm.engine.protocol import EngineClient
from vllm.inputs.data import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
@ -30,14 +33,36 @@ logger = init_logger(__name__)
DEFAULT_MAX_TOKENS = 32000
class PDClient:
class PDController(EngineClient):
"""
PDEngine:
Equivalent of AsyncLLM for P/D. Assumes there are
Prefill and Decode services already running.
Controller that schedules work on the PDWorkers.
* TODO: actually handle errors and failure.
* TODO: support more than just text input.
Conforms to the EngineClient protocol so it can
be wrapped with the OpenAI Server.
Two Phases:
* Send request to prefill worker, await ack.
* Send request to decode worker.
KVSync happens directly between Engines,
handled by vLLM KVCacheTransfer.
[ OpenAI Server ]
        |
[ PDController ]
        |
     [ zmq ]
     /      \
[ PDWorker ]  [ PDWorker ]
     |            |
[ Engine ] <---> [ Engine ]
After PR #12957, we will support xPyD, so we will
also need to implement a scheduler.
* TODO: actually handle errors and failure.
* TODO: support the full API (logprobs, multimodal).
"""
def __init__(self, prefill_addr: str, decode_addr: str,
@ -49,6 +74,8 @@ class PDClient:
self.encoder = msgspec.msgpack.Encoder()
# ZMQ communication.
# TODO: once https://github.com/vllm-project/vllm/pull/12957
# lands, do service discovery to scale out workers.
self.ctx = zmq.asyncio.Context()
self.to_decode = self.ctx.socket(zmq.constants.PUSH)
self.to_decode.bind(f"{decode_addr}")
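For context, a sketch of the worker-side counterpart of these sockets (assumed, not shown in this hunk): since the controller binds PUSH sockets, each worker connects a PULL socket and receives (type, payload) frames.

import zmq
import zmq.asyncio

# Hypothetical worker-side setup; the address must match the
# controller's bind address.
ctx = zmq.asyncio.Context()
from_controller = ctx.socket(zmq.constants.PULL)
from_controller.connect("ipc:///tmp/decode")

async def recv_one():
    # Mirrors the controller's send_multipart((req_type, req_bytes)).
    req_type, req_data = await from_controller.recv_multipart()
    return req_type, req_data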
@ -63,7 +90,7 @@ class PDClient:
self.log_running: Optional[asyncio.Task] = None
# Dummy: needed for EngineClient Protocol.
# TODO: refactor EngineClient to avoid needing this.
# TODO: refactor OAI Server to avoid needing this.
self.model_config = ModelConfig(model=model_name,
tokenizer=model_name,
tokenizer_mode="auto",
@ -73,7 +100,7 @@ class PDClient:
seed=42)
# Dummy: needed for EngineClient Protocol.
# TODO: refactor EngineClient to avoid needing this.
# TODO: refactor OAI Server to avoid needing this.
init_kwargs = dict(
tokenizer_id=self.model_config.tokenizer,
enable_lora=False,
@ -109,7 +136,7 @@ class PDClient:
Pull responses from Decode + Prefill engines and
distribute back to the generate() tasks.
"""
decoder = msgspec.msgpack.Decoder(PDResponse)
decoder = msgspec.msgpack.Decoder(PDGenerationResponse)
socket: Optional[zmq.asyncio.Socket] = None
try:
@ -117,72 +144,78 @@ class PDClient:
socket.bind(self.connector_addr)
while True:
response_bytes = await socket.recv()
response = decoder.decode(response_bytes)
logger.debug("Got Response: %s", response.request_id)
self.queues[response.request_id].put_nowait(response)
except:
# TODO: actually handle failure and shutdown.
raise
res_type, res_data = await socket.recv_multipart()
if res_type == PDResponseType.FAILURE:
raise Exception("Failure Response from PDWorker.")
elif res_type == PDResponseType.GENERATION:
response = decoder.decode(res_data)
logger.debug("Got Response: %s", response.request_id)
self.queues[response.request_id].put_nowait(response)
else:
raise Exception("Unknown response type.")
except Exception as e:
# TODO: distinguish between fatal and non-fatal errors.
for q in self.queues.values():
q.put_nowait(e)
raise e
finally:
if socket is not None:
socket.close(linger=0)
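To make the framing concrete, here is a minimal local round trip of the (type, payload) scheme handled above. The PDGenerationResponse constructor fields are assumptions inferred from how the response is read elsewhere in this diff:

import msgspec

encoder = msgspec.msgpack.Encoder()
decoder = msgspec.msgpack.Decoder(PDGenerationResponse)

# Worker side: tag the msgpack payload with a response-type frame.
response = PDGenerationResponse(request_id="req-0", text="Hi",
                                token_ids=[42], finish_reason="stop",
                                stop_reason=None)
frames = (PDResponseType.GENERATION, encoder.encode(response))

# Controller side: dispatch on the type frame, then decode the payload.
res_type, res_data = frames
assert res_type == PDResponseType.GENERATION
assert decoder.decode(res_data).request_id == "req-0"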
async def _prefill(
async def _run_prefill(
self,
request: PDRequest,
q: asyncio.Queue[PDResponse],
request: PDGenerationRequest,
q: asyncio.Queue[Union[Exception, PDGenerationResponse]],
):
# Send request to the prefill instance.
req_bytes = self.encoder.encode(request)
await self.to_prefill.send(req_bytes, copy=False)
msg = (PDRequestType.GENERATION, req_bytes)
await self.to_prefill.send_multipart(msg, copy=False)
# Wait for the prefill to be done.
response = await q.get()
if not response.success:
# TODO: actual error handling and shutdown.
raise Exception("Failed Prefill Request.")
if isinstance(response, Exception):
raise response
assert response.request_id == request.request_id
async def _decode(
async def _run_decode(
self,
request: PDRequest,
q: asyncio.Queue[PDResponse],
) -> AsyncGenerator[PDRequest]:
request: PDGenerationRequest,
q: asyncio.Queue[Union[Exception, PDGenerationResponse]],
) -> AsyncGenerator[PDGenerationResponse, None]:
# Send request to the decode instance.
req_bytes = self.encoder.encode(request)
await self.to_decode.send(req_bytes, copy=False)
msg = (PDRequestType.GENERATION, req_bytes)
await self.to_decode.send_multipart(msg, copy=False)
# Iterate response queue and yield each response to caller..
# Iterate response queue and yield each response to caller.
finished = False
while not finished:
response = await q.get()
if not response.success:
# TODO: actual error handling and shutdown.
raise Exception("Failed Decode Request.")
if isinstance(response, Exception):
raise response
finished = response.finish_reason is not None
yield response
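The queue doubles as an error channel: the output handler enqueues an Exception and the per-request coroutines re-raise it. A self-contained sketch of that pattern, using only the standard library:

import asyncio

async def consume(q: asyncio.Queue) -> str:
    item = await q.get()
    if isinstance(item, Exception):
        raise item  # re-raise the handler's failure in this request's task
    return item

async def demo():
    q: asyncio.Queue = asyncio.Queue()
    q.put_nowait(RuntimeError("worker died"))
    try:
        await consume(q)
    except RuntimeError as e:
        print(f"propagated: {e}")

asyncio.run(demo())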
def _to_request_output(
self,
pd_response: PDResponse,
response: PDGenerationResponse,
prompt_token_ids: list[int],
) -> RequestOutput:
finished = pd_response.finish_reason is not None
finished = response.finish_reason is not None
return RequestOutput(
request_id=pd_response.request_id,
request_id=response.request_id,
prompt=None,
prompt_token_ids=prompt_token_ids,
prompt_logprobs=None,
outputs=[
CompletionOutput(index=0,
text=pd_response.text,
token_ids=pd_response.token_ids,
text=response.text,
token_ids=response.token_ids,
cumulative_logprob=None,
logprobs=None,
finish_reason=pd_response.finish_reason,
stop_reason=pd_response.stop_reason)
finish_reason=response.finish_reason,
stop_reason=response.stop_reason)
],
finished=finished,
)
@ -223,24 +256,24 @@ class PDClient:
raise ValueError(f"Found duplicate request_id: {request_id}!")
# Queue to gather output from output_handler.
q: asyncio.Queue[PDResponse] = asyncio.Queue()
q: asyncio.Queue[Union[Exception, PDGenerationResponse]] = asyncio.Queue()
self.queues[request_id] = q
# (1) Perform the Prefill.
original_max_tokens = sampling_params.max_tokens
prompt_token_ids = prompt["prompt_token_ids"]
request = PDRequest(request_id=request_id,
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params)
request = PDGenerationRequest(request_id=request_id,
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params)
request.sampling_params.max_tokens = 1
logger.debug("Sending Prefill: %s", request.request_id)
pd_response = await self._prefill(request, q)
await self._run_prefill(request, q)
# (2) Perform the Decodes.
logger.debug("Sending Decode: %s", request.request_id)
request.sampling_params.max_tokens = original_max_tokens
async for pd_response in self._decode(request, q):
logger.debug("Got Decode: %s", request.request_id)
async for pd_response in self._run_decode(request, q):
logger.debug("Got Response: %s", request.request_id)
yield self._to_request_output(pd_response, prompt_token_ids)
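The max_tokens swap above is the core of the two-phase trick: the prefill worker generates exactly one token (which fills the KV cache), then the original budget is restored for decode. A standalone sketch of just that manipulation, using vLLM's real SamplingParams:

from vllm import SamplingParams

params = SamplingParams(max_tokens=128)
original_max_tokens = params.max_tokens

params.max_tokens = 1  # prefill phase: one token, KV cache filled
# ... send request to the prefill worker, await the ack ...
params.max_tokens = original_max_tokens  # decode phase: full budget
assert params.max_tokens == 128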
async def beam_search(

View File

@ -65,9 +65,9 @@ class PDWorker:
async def run_busy_loop(self):
"""
Main execution loop for the PDWorker:
* 1) Wait for a request from the PDClient.
* 2) Handle the request.
main loop:
1) wait for a request from the PDClient
2) handle the request
"""
logger.info("PDWorker is ready To handle requests.")
@ -76,10 +76,14 @@ class PDWorker:
req_type, req_data = await self.from_client.recv_multipart()
# 2) Handle the request.
await self._handle_request(req_type, req_data.buffer)
await self._handle_request(req_type, req_data)
async def _handle_request(self, req_type: bytes, req_data: bytes):
"""Parse the request type and call the appropriate handler."""
"""
request handler:
1) parse the request type
2) call the appropriate handler for the request type
"""
if req_type == PDRequestType.GENERATION:
req = self.decode_generation(req_data)
await self._generation_handler(req)
@ -88,13 +92,18 @@ class PDWorker:
await self._abort_handler(req)
async def _generation_handler(self, req: PDGenerationRequest):
"""Launch generation in a background task."""
"""
Handle a PDGenerationRequest by launching a task.
"""
task = asyncio.create_task(self._generate(req))
self.running_requests.add(task)
task.add_done_callback(self.running_requests.discard)
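A note on this pattern: asyncio keeps only a weak reference to tasks, so the set holds a strong reference until the done callback discards it. Minimal standalone sketch:

import asyncio

running: set[asyncio.Task] = set()

async def work():
    await asyncio.sleep(0)

async def demo():
    task = asyncio.create_task(work())
    running.add(task)  # strong ref: prevents mid-flight garbage collection
    task.add_done_callback(running.discard)  # drop the ref when done
    await task
    assert not running  # callback has removed the finished task

asyncio.run(demo())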
async def _abort_handler(self, req: PDGenerationRequest):
"""Abort the request in the engine."""
"""
Handle a PDAbortRequest by cancelling the running task.
The _generate coro aborts in the Engine.
"""
# Convert running_requests set() into a dict(), keyed
# by request_id. Cancel the task when an abort comes in.
# Then update the _generate coroutine to handle a
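A hedged sketch of the bookkeeping this comment describes (every name here is an assumption for illustration, not code from this commit): key tasks by request_id so an abort cancels exactly one task.

import asyncio

running_by_id: dict[str, asyncio.Task] = {}

def launch(request_id: str, coro) -> None:
    task = asyncio.create_task(coro)
    running_by_id[request_id] = task
    task.add_done_callback(lambda _: running_by_id.pop(request_id, None))

def abort(request_id: str) -> None:
    task = running_by_id.get(request_id)
    if task is not None:
        task.cancel()  # _generate would then abort the request in the Engine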
@ -103,11 +112,11 @@ class PDWorker:
async def _generate(self, req: PDGenerationRequest):
"""
Handle a single PDRequest:
Handle a single PDGenerationRequest:
* 1) submit request to AsyncLLM
* 2) iterate the RequestOutputs
* 3) convert RequestOutput --> PDResponse
* 4) serialize and send to Connector.
* 4) serialize and send to PDClient
"""
request_id = req.request_id

View File

@ -55,9 +55,3 @@ class PDGenerationResponse(msgspec.Struct):
finish_reason=out.finish_reason,
stop_reason=out.stop_reason,
)
class PDGenerationFailure(msgspec.Struct):
request_id: str
error_message: str
engine_dead: bool
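These message types are msgspec Structs, which give typed, fast msgpack (de)serialization. A minimal round trip with a hypothetical stand-in struct (Ping is not one of this commit's types):

import msgspec

class Ping(msgspec.Struct):
    request_id: str
    ok: bool

buf = msgspec.msgpack.Encoder().encode(Ping(request_id="r0", ok=True))
assert msgspec.msgpack.Decoder(Ping).decode(buf).ok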

View File

@ -1,5 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
"""Toy Server For Prototyping."""
"""
Toy connector for prototyping.
When PDClient supports the protocol and we clean up the
OpenAI Server, we can drop this in favor of vllm serve.
"""
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
@ -9,7 +14,7 @@ import uvloop
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.entrypoints.disaggregated.engine import PDEngine
from vllm.disaggregated.pd_client import PDClient
from vllm.entrypoints.openai.protocol import (CompletionRequest,
CompletionResponse,
ErrorResponse)
@ -46,34 +51,31 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
@asynccontextmanager
async def pd_engine_client_ctx_manager(
prefill_addr: str, decode_addr: str, connector_addr: str,
model_name: str) -> AsyncIterator[PDEngine]:
engine = PDEngine(prefill_addr, decode_addr, connector_addr, model_name)
yield engine
engine.shutdown()
async def pd_client_ctx(prefill_addr: str, decode_addr: str,
connector_addr: str,
model_name: str) -> AsyncIterator[PDClient]:
client = PDClient(prefill_addr, decode_addr, connector_addr, model_name)
yield client
client.shutdown()
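One caveat with this context manager: the bare yield skips shutdown() if the body raises. A hedged alternative sketch wrapping it in try/finally:

from contextlib import asynccontextmanager

@asynccontextmanager
async def pd_client_ctx_safe(prefill_addr, decode_addr, connector_addr,
                             model_name):
    client = PDClient(prefill_addr, decode_addr, connector_addr, model_name)
    try:
        yield client
    finally:
        client.shutdown()  # runs even if the server loop raises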
async def main(args, **uvicorn_kwargs):
logger.info("vLLM Disaggregate Connector Start %s %s", args,
logger.info("vLLM Disaggregated Connector Start %s %s", args,
uvicorn_kwargs)
# Avoid dropping requests under high concurrency.
set_ulimit()
# IPC Paths.
# NOTE FOR DEVELOPERS: when shifting to TCP, ensure you
# are not using pickle to avoid RCE security flaw.
prefill_addr = f"ipc://{args.prefill_addr}"
decode_addr = f"ipc://{args.decode_addr}"
connector_addr = f"ipc://{args.connector_addr}"
# Start Engine.
async with pd_engine_client_ctx_manager(
prefill_addr=prefill_addr,
decode_addr=decode_addr,
connector_addr=connector_addr,
model_name=args.model) as engine_client:
async with pd_client_ctx(prefill_addr=prefill_addr,
decode_addr=decode_addr,
connector_addr=connector_addr,
model_name=args.model) as engine_client:
# Initialize App State.
model_config = await engine_client.get_model_config()
@ -93,7 +95,7 @@ async def main(args, **uvicorn_kwargs):
)
# Run Server.
config = uvicorn.Config(app, host="0.0.0.0", port=args.port)
config = uvicorn.Config(app, host=args.host, port=args.port)
server = uvicorn.Server(config)
await server.serve()