Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
rshaw@neuralmagic.com 2025-03-23 23:03:42 +00:00
parent a10da86677
commit 7954461d4c
5 changed files with 517 additions and 14 deletions

View File

@@ -50,9 +50,9 @@ PREFILL_WORKER_ADDR=prefillipc
 DECODE_WORKER_ADDR=decodeipc
 # prefilling instance, which is the KV producer
-CUDA_VISIBLE_DEVICES=0 python3 ../vllm/entrypoints/disaggregated/worker.py \
+CUDA_VISIBLE_DEVICES=0 python3 ../../vllm/entrypoints/disaggregated/worker.py \
     --model $MODEL \
-    --connector-addr $controller_addr \
+    --controller-addr $CONTROLLER_ADDR \
     --worker-addr $PREFILL_WORKER_ADDR \
     --max-model-len 100 \
     --gpu-memory-utilization 0.8 \
@@ -60,9 +60,9 @@ CUDA_VISIBLE_DEVICES=0 python3 ../vllm/entrypoints/disaggregated/worker.py \
     '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' > vllm_disagg_prefill.log 2>&1 &
 # decoding instance, which is the KV consumer
-CUDA_VISIBLE_DEVICES=1 python3 ../vllm/entrypoints/disaggregated/worker.py \
+CUDA_VISIBLE_DEVICES=1 python3 ../../vllm/entrypoints/disaggregated/worker.py \
     --model $MODEL \
-    --connector-addr $controller_addr \
+    --controller-addr $CONTROLLER_ADDR \
     --worker-addr $DECODE_WORKER_ADDR \
     --max-model-len 100 \
     --gpu-memory-utilization 0.8 \
@@ -73,10 +73,10 @@ CUDA_VISIBLE_DEVICES=1 python3 ../vllm/entrypoints/disaggregated/worker.py \
 # the workflow of this proxy:
 # - Send req to prefill instance, wait until complete.
 # - Send req to decode instance, streaming tokens.
-python3 ../vllm/entrypoints/disaggregated/connector.py \
+python3 ../../vllm/entrypoints/disaggregated/api_server.py \
     --port $PORT \
     --model $MODEL \
-    --connector-addr $controller_addr \
+    --controller-addr $CONTROLLER_ADDR \
     --prefill-addr $PREFILL_WORKER_ADDR \
     --decode-addr $DECODE_WORKER_ADDR
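
With all three processes up, the stack can be smoke-tested end to end. A minimal sketch, assuming the api_server listens on localhost at the default port 8000 and that MODEL_NAME stands in for whatever $MODEL resolves to above:

# smoke_test.py -- minimal sketch; PORT and MODEL_NAME are assumptions that
# must match the $PORT and $MODEL values exported earlier in this script.
import json
import urllib.request

PORT = 8000
MODEL_NAME = "facebook/opt-125m"  # hypothetical; use your $MODEL

payload = {
    "model": MODEL_NAME,
    "prompt": "Disaggregated prefill is",
    "max_tokens": 16,
    "temperature": 0.0,
}
req = urllib.request.Request(
    f"http://localhost:{PORT}/v1/completions",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    # Prefill ran on GPU 0, decode streamed from GPU 1; the proxy merges them.
    print(json.load(resp)["choices"][0]["text"])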

View File

@@ -0,0 +1,368 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import os
from collections.abc import AsyncGenerator, Mapping
from typing import Optional, Union

import msgspec
import zmq
import zmq.asyncio

from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.disaggregated.protocol import (PDGenerationRequest,
                                         PDGenerationResponse, PDRequestType,
                                         PDResponseType)
from vllm.engine.protocol import EngineClient
from vllm.inputs.data import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.utils import Device

logger = init_logger(__name__)

DEFAULT_MAX_TOKENS = 32000


class PDController(EngineClient):
    """
    Controller that schedules work on the PDWorkers.
    Conforms to the EngineClient protocol so it can
    be wrapped with the OpenAI Server.

    Two Phases:
      * Send request to prefill worker, await ack.
      * Send request to decode worker, stream responses.

    KV sync happens directly between Engines,
    handled by vLLM KVCacheTransfer.

            [ OpenAI Server ]
                   |
            [ PDController ]
                   |
                 [ zmq ]
                /      \
       [ PDWorker ]  [ PDWorker ]
            |             |
        [ Engine ] <---> [ Engine ]

    After PR #12957, we will support xPyD, so we will
    also need to implement a scheduler and service
    discovery for the workers. This PDController may be
    implemented as a K8s controller. This is intended
    to be a prototype.

    * TODO: better error handling
    * TODO: support logprobs, multimodal, etc.
    """

    def __init__(self, prefill_addr: str, decode_addr: str,
                 controller_addr: str, model_name: str):
        # Request queues.
        self.queues: dict[str, asyncio.Queue] = {}

        # Serialization encoder.
        self.encoder = msgspec.msgpack.Encoder()

        # ZMQ communication.
        # TODO: once https://github.com/vllm-project/vllm/pull/12957
        # lands, do service discovery to scale out workers.
        self.ctx = zmq.asyncio.Context()
        self.to_decode = self.ctx.socket(zmq.constants.PUSH)
        self.to_decode.bind(decode_addr)
        self.to_prefill = self.ctx.socket(zmq.constants.PUSH)
        self.to_prefill.bind(prefill_addr)
        self.controller_addr = controller_addr
        self.decode_addr = decode_addr
        self.prefill_addr = prefill_addr

        # Background loops (started on first generate()).
        self.output_handler: Optional[asyncio.Task] = None
        self.log_running: Optional[asyncio.Task] = None

        # Dummy: needed for EngineClient Protocol.
        # TODO: refactor OAI Server to avoid needing this.
        self.model_config = ModelConfig(model=model_name,
                                        tokenizer=model_name,
                                        tokenizer_mode="auto",
                                        trust_remote_code=False,
                                        dtype="auto",
                                        task="generate",
                                        seed=42)

        # Dummy: needed for EngineClient Protocol.
        # TODO: refactor OAI Server to avoid needing this.
        self.tokenizer = TokenizerGroup(
            **dict(tokenizer_id=self.model_config.tokenizer,
                   enable_lora=False,
                   max_num_seqs=1024,
                   max_loras=0,
                   max_input_length=None,
                   tokenizer_mode=self.model_config.tokenizer_mode,
                   trust_remote_code=self.model_config.trust_remote_code,
                   revision=self.model_config.tokenizer_revision,
                   truncation_side=self.model_config.truncation_side))

    def shutdown(self):
        if (ctx := self.ctx) is not None:
            ctx.destroy(linger=0)
        if (task := self.log_running) is not None:
            task.cancel()
        if (task := self.output_handler) is not None:
            task.cancel()

        ipc_paths = [self.controller_addr, self.decode_addr, self.prefill_addr]
        for path in ipc_paths:
            socket_path = path.replace("ipc://", "")
            if os.path.exists(socket_path):
                os.remove(socket_path)

    async def _run_log_running(self):
        while True:
            logger.info("Running requests: %d", len(self.queues))
            await asyncio.sleep(10.)

    async def _run_output_handler(self):
        """
        Pull responses from Decode + Prefill engines and
        distribute back to the generate() tasks.
        """
        decoder = msgspec.msgpack.Decoder(PDGenerationResponse)
        socket: Optional[zmq.asyncio.Socket] = None

        try:
            socket = self.ctx.socket(zmq.constants.PULL)
            socket.bind(self.controller_addr)

            while True:
                res_type, res_data = await socket.recv_multipart()
                if res_type == PDResponseType.FAILURE:
                    raise Exception("Failure Response from PDWorker.")
                elif res_type == PDResponseType.GENERATION:
                    response = decoder.decode(res_data)
                    logger.debug("Got Response: %s", response.request_id)
                    self.queues[response.request_id].put_nowait(response)
                else:
                    raise Exception("Unknown response type.")
        except Exception as e:
            # TODO: distinguish between fatal and non-fatal errors.
            for q in self.queues.values():
                q.put_nowait(e)
            raise e
        finally:
            if socket is not None:
                socket.close(linger=0)

    async def _run_prefill(
        self,
        request: PDGenerationRequest,
        q: asyncio.Queue[Union[Exception, PDGenerationResponse]],
    ):
        # Send request to the prefill instance.
        req_bytes = self.encoder.encode(request)
        msg = (PDRequestType.GENERATION, req_bytes)
        await self.to_prefill.send_multipart(msg, copy=False)

        # Await completion of the prefill.
        response = await q.get()
        if isinstance(response, Exception):
            raise response
        logger.debug("Got Prefill Response: %s", request.request_id)

    async def _run_decode(
        self,
        request: PDGenerationRequest,
        q: asyncio.Queue[Union[Exception, PDGenerationResponse]],
    ) -> AsyncGenerator[PDGenerationResponse, None]:
        # Send request to the decode instance.
        req_bytes = self.encoder.encode(request)
        msg = (PDRequestType.GENERATION, req_bytes)
        await self.to_decode.send_multipart(msg, copy=False)

        # Iterate response queue and yield each response to caller.
        finished = False
        while not finished:
            response = await q.get()
            if isinstance(response, Exception):
                raise response
            logger.debug("Got Decode Response: %s", request.request_id)
            finished = response.finish_reason is not None
            yield response

    def _to_request_output(
        self,
        response: PDGenerationResponse,
        prompt_token_ids: list[int],
    ) -> RequestOutput:
        finished = response.finish_reason is not None
        return RequestOutput(
            request_id=response.request_id,
            prompt=None,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,
            outputs=[
                CompletionOutput(index=0,
                                 text=response.text,
                                 token_ids=response.token_ids,
                                 cumulative_logprob=None,
                                 logprobs=None,
                                 finish_reason=response.finish_reason,
                                 stop_reason=response.stop_reason)
            ],
            finished=finished,
        )

    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        # Start loops on first request.
        if self.output_handler is None:
            self.output_handler = asyncio.create_task(
                self._run_output_handler())
            self.log_running = asyncio.create_task(self._run_log_running())

        # TODO: Expand to support the full matrix.
        if "prompt_token_ids" not in prompt:
            raise NotImplementedError(
                "We currently only support TokensPrompt for P/D!")
        if lora_request is not None:
            raise NotImplementedError(
                "We currently do not support LoRA for P/D!")
        if trace_headers is not None:
            raise NotImplementedError(
                "We currently do not support tracing for P/D!")
        if prompt_adapter_request is not None:
            raise NotImplementedError(
                "We currently do not support prompt adapter for P/D!")
        if priority != 0:
            raise NotImplementedError(
                "We currently do not support priority for P/D!")
        if request_id in self.queues:
            raise ValueError(f"Found duplicate request_id: {request_id}!")

        # Queue to gather output from output_handler.
        q = asyncio.Queue()
        self.queues[request_id] = q

        # (1) Perform the Prefill.
        original_max_tokens = sampling_params.max_tokens
        request = PDGenerationRequest(
            request_id=request_id,
            prompt_token_ids=prompt["prompt_token_ids"],
            sampling_params=sampling_params)
        request.sampling_params.max_tokens = 1
        logger.debug("Sending Prefill: %s", request.request_id)
        await self._run_prefill(request, q)

        # (2) Perform the Decodes.
        logger.debug("Sending Decode: %s", request.request_id)
        request.sampling_params.max_tokens = original_max_tokens
        async for pd_response in self._run_decode(request, q):
            yield self._to_request_output(pd_response,
                                          prompt["prompt_token_ids"])

    async def beam_search(
        self,
        prompt: PromptType,
        request_id: str,
        params: BeamSearchParams,
    ) -> AsyncGenerator[RequestOutput, None]:
        raise NotImplementedError

    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        raise NotImplementedError

    async def abort(self, request_id: str) -> None:
        raise NotImplementedError

    async def get_model_config(self) -> ModelConfig:
        return self.model_config

    async def get_decoding_config(self) -> DecodingConfig:
        raise NotImplementedError

    async def get_input_preprocessor(self) -> InputPreprocessor:
        raise NotImplementedError

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        if lora_request is not None:
            raise NotImplementedError(
                "LoRA is not yet supported in the PDEngine.")
        return self.tokenizer.get_lora_tokenizer(None)

    async def is_tracing_enabled(self) -> bool:
        return False

    async def do_log_stats(
        self,
        scheduler_outputs: Optional[SchedulerOutputs] = None,
        model_output: Optional[list[SamplerOutput]] = None,
    ) -> None:
        pass

    async def check_health(self) -> None:
        pass

    async def start_profile(self) -> None:
        raise NotImplementedError

    async def stop_profile(self) -> None:
        raise NotImplementedError

    async def reset_prefix_cache(self,
                                 device: Optional[Device] = None) -> None:
        raise NotImplementedError

    async def sleep(self, level: int = 1) -> None:
        raise NotImplementedError

    async def wake_up(self) -> None:
        raise NotImplementedError

    async def is_sleeping(self) -> bool:
        return False

    async def add_lora(self, lora_request: LoRARequest) -> None:
        raise NotImplementedError

    @property
    def errored(self) -> bool:
        return False

    @property
    def dead_error(self) -> Exception:
        return Exception("PDController has failed.")

    @property
    def is_running(self) -> bool:
        return True

    @property
    def is_stopped(self) -> bool:
        return False
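
The wire contract between this file and the PDWorker is easiest to see in isolation: each message is a two-frame zmq multipart of (type tag, msgpack body), encoded with msgspec. A runnable sketch of that framing, using a hypothetical PDDemoRequest in place of the real PDGenerationRequest from vllm.disaggregated.protocol, and an assumed one-byte tag standing in for PDRequestType.GENERATION:

# protocol_sketch.py -- standalone sketch of the controller<->worker framing.
import msgspec
import zmq


class PDDemoRequest(msgspec.Struct):
    request_id: str
    prompt_token_ids: list[int]


GENERATION = b"\x00"  # assumed tag; the real enum is PDRequestType

ctx = zmq.Context()
push = ctx.socket(zmq.PUSH)  # controller side: binds, like to_prefill
push.bind("ipc:///tmp/pd_demo_worker")
pull = ctx.socket(zmq.PULL)  # worker side: connects, like from_client
pull.connect("ipc:///tmp/pd_demo_worker")

encoder = msgspec.msgpack.Encoder()
decoder = msgspec.msgpack.Decoder(PDDemoRequest)

request = PDDemoRequest(request_id="req-0", prompt_token_ids=[1, 2, 3])
push.send_multipart((GENERATION, encoder.encode(request)))

req_type, req_data = pull.recv_multipart()
assert req_type == GENERATION
print(decoder.decode(req_data))  # PDDemoRequest(request_id='req-0', ...)

ctx.destroy(linger=0)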

View File

@@ -21,7 +21,7 @@ class PDWorker:
         self,
         engine: EngineClient,
         worker_addr: str,
-        client_addr: str,
+        controller_addr: str,
     ):
         """
         PDWorker
@@ -35,12 +35,12 @@ class PDWorker:
         # ZMQ IPC.
         self.worker_addr = worker_addr
-        self.client_addr = client_addr
+        self.controller_addr = controller_addr
         self.ctx = zmq.asyncio.Context()
         self.from_client = self.ctx.socket(zmq.constants.PULL)
         self.from_client.connect(f"ipc://{self.worker_addr}")
         self.to_client = self.ctx.socket(zmq.constants.PUSH)
-        self.to_client.connect(f"ipc://{self.client_addr}")
+        self.to_client.connect(f"ipc://{self.controller_addr}")
         self.decode_generation = msgspec.msgpack.Decoder(PDGenerationRequest)
         self.decode_abort = msgspec.msgpack.Decoder(PDAbortRequest)
         self.encoder = msgspec.msgpack.Encoder()
@@ -56,8 +56,8 @@ class PDWorker:
         for running_request in self.running_requests:
             running_request.cancel()
-        if hasattr(self, "client_addr"):
-            ipc_paths = [self.worker_addr, self.client_addr]
+        if hasattr(self, "controller_addr"):
+            ipc_paths = [self.worker_addr, self.controller_addr]
             for ipc_path in ipc_paths:
                 socket_path = ipc_path.replace("ipc://", "")
                 if os.path.exists(socket_path):
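
For orientation, a minimal sketch of the socket topology this rename clarifies, with made-up addresses: the controller binds all three ipc endpoints (prefill, decode, and its own response address), while each worker connects, mirroring PDController.__init__ and PDWorker.__init__ above:

# topology_sketch.py -- who binds and who connects in this commit.
import os
import zmq

CONTROLLER_ADDR = "ipc:///tmp/pd_demo_controller"
WORKER_ADDR = "ipc:///tmp/pd_demo_prefill"

ctx = zmq.Context()

# Controller side (PDController.__init__): bind both directions.
to_worker = ctx.socket(zmq.PUSH)
to_worker.bind(WORKER_ADDR)
from_workers = ctx.socket(zmq.PULL)
from_workers.bind(CONTROLLER_ADDR)

# Worker side (PDWorker.__init__): the mirror image, connecting instead.
from_controller = ctx.socket(zmq.PULL)
from_controller.connect(WORKER_ADDR)
to_controller = ctx.socket(zmq.PUSH)
to_controller.connect(CONTROLLER_ADDR)

to_worker.send(b"request")
assert from_controller.recv() == b"request"
to_controller.send(b"response")
assert from_workers.recv() == b"response"

ctx.destroy(linger=0)
# ipc endpoints leave socket files behind, which is why both shutdown
# paths in this commit unlink them.
for addr in (CONTROLLER_ADDR, WORKER_ADDR):
    path = addr.replace("ipc://", "")
    if os.path.exists(path):
        os.remove(path)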

View File

@@ -0,0 +1,135 @@
# SPDX-License-Identifier: Apache-2.0
"""
Toy connector for prototyping.
When PDController supports the protocol and we clean up the
OpenAI Server, we can drop this in favor of vllm serve.
"""
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager

import uvicorn
import uvloop
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse

from vllm.disaggregated.pd_controller import PDController
from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                              CompletionResponse,
                                              ErrorResponse)
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    OpenAIServingModels)
from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser, set_ulimit

# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.disaggregated.api_server')

app = FastAPI()


@app.get("/v1/models")
async def show_available_models(raw_request: Request):
    handler: OpenAIServingModels = raw_request.app.state.openai_serving_models
    models_ = await handler.show_available_models()
    return JSONResponse(content=models_.model_dump())


@app.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    handler: OpenAIServingCompletion = raw_request.app.state.openai_serving_completion  # noqa: E501
    generator = await handler.create_completion(request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    elif isinstance(generator, CompletionResponse):
        return JSONResponse(content=generator.model_dump())
    return StreamingResponse(content=generator, media_type="text/event-stream")

@asynccontextmanager
async def controller_ctx(prefill_addr: str, decode_addr: str,
                         controller_addr: str,
                         model_name: str) -> AsyncIterator[PDController]:
    c = PDController(prefill_addr, decode_addr, controller_addr, model_name)
    try:
        yield c
    finally:
        c.shutdown()

async def main(args, **uvicorn_kwargs):
    logger.info("vLLM Disaggregated Connector Start %s %s", args,
                uvicorn_kwargs)

    # Avoid dropping requests under high concurrency.
    set_ulimit()

    # IPC Paths.
    prefill_addr = f"ipc://{args.prefill_addr}"
    decode_addr = f"ipc://{args.decode_addr}"
    controller_addr = f"ipc://{args.controller_addr}"

    # Start Engine.
    async with controller_ctx(prefill_addr=prefill_addr,
                              decode_addr=decode_addr,
                              controller_addr=controller_addr,
                              model_name=args.model) as engine_client:
        # Initialize App State.
        model_config = await engine_client.get_model_config()
        app.state.openai_serving_models = OpenAIServingModels(
            engine_client=engine_client,
            model_config=model_config,
            base_model_paths=[
                BaseModelPath(name=args.served_model_name or args.model,
                              model_path=args.model)
            ],
        )
        app.state.openai_serving_completion = OpenAIServingCompletion(
            engine_client=engine_client,
            model_config=model_config,
            models=app.state.openai_serving_models,
            request_logger=None,
        )

        # Run Server.
        config = uvicorn.Config(app, host=args.host, port=args.port)
        server = uvicorn.Server(config)
        await server.serve()


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="vLLM OpenAI-Compatible P/D Server.")
    parser.add_argument("--host",
                        type=str,
                        default="0.0.0.0",
                        help="The host of the HTTP server.")
    parser.add_argument("--port",
                        type=int,
                        default=8000,
                        help="The port of the HTTP server.")
    parser.add_argument("--model",
                        type=str,
                        required=True,
                        help="The path to the model.")
    parser.add_argument("--served-model-name",
                        type=str,
                        default=None,
                        help="The served name of the model.")
    parser.add_argument("--controller-addr",
                        type=str,
                        required=True,
                        help="The zmq ipc controller address")
    parser.add_argument("--prefill-addr",
                        type=str,
                        required=True,
                        help="The zmq ipc prefill address")
    parser.add_argument("--decode-addr",
                        type=str,
                        required=True,
                        help="The zmq ipc decode address")
    args = parser.parse_args()

    uvloop.run(main(args))
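
For completeness, a sketch of consuming the StreamingResponse branch of create_completion above. Host, port, and model name are assumptions; the framing is the standard OpenAI-style SSE that vLLM emits: "data: <json>" lines terminated by "data: [DONE]".

# stream_client.py -- sketch of a streaming /v1/completions client.
import json
import urllib.request

payload = {
    "model": "facebook/opt-125m",  # hypothetical served model name
    "prompt": "Hello",
    "max_tokens": 8,
    "stream": True,
}
req = urllib.request.Request(
    "http://localhost:8000/v1/completions",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    for raw_line in resp:
        line = raw_line.decode().strip()
        if not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        # Each chunk carries newly decoded text from the decode worker.
        print(json.loads(data)["choices"][0]["text"], end="", flush=True)
print()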

View File

@@ -15,7 +15,7 @@ logger = init_logger(__name__)
 async def run(args, engine: EngineClient):
     try:
-        worker = PDWorker(engine, args.worker_addr, args.client_addr)
+        worker = PDWorker(engine, args.worker_addr, args.controller_addr)
         await worker.run_busy_loop()
     finally:
         worker.shutdown()
@@ -32,10 +32,10 @@ async def main(args) -> None:
 if __name__ == "__main__":
     parser = FlexibleArgumentParser()
-    parser.add_argument('--client-addr',
+    parser.add_argument('--controller-addr',
                         type=str,
                         required=True,
-                        help='The address of the connector.')
+                        help='The address of the controller.')
     parser.add_argument('--worker-addr',
                         type=str,
                         required=True,