From 7954461d4cf72df7db092e94f5ec7c9069636565 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Sun, 23 Mar 2025 23:03:42 +0000 Subject: [PATCH] updated Signed-off-by: rshaw@neuralmagic.com --- .../disaggregated_prefill_zmq.sh | 12 +- vllm/disaggregated/pd_controller.py | 368 ++++++++++++++++++ vllm/disaggregated/pd_worker.py | 10 +- vllm/entrypoints/disaggregated/api_server.py | 135 +++++++ vllm/entrypoints/disaggregated/worker.py | 6 +- 5 files changed, 517 insertions(+), 14 deletions(-) create mode 100644 vllm/disaggregated/pd_controller.py create mode 100644 vllm/entrypoints/disaggregated/api_server.py diff --git a/examples/online_serving/disaggregated_prefill_zmq.sh b/examples/online_serving/disaggregated_prefill_zmq.sh index f881c4e49a92c..adea0c6929fb0 100644 --- a/examples/online_serving/disaggregated_prefill_zmq.sh +++ b/examples/online_serving/disaggregated_prefill_zmq.sh @@ -50,9 +50,9 @@ PREFILL_WORKER_ADDR=prefillipc DECODE_WORKER_ADDR=decodeipc # prefilling instance, which is the KV producer -CUDA_VISIBLE_DEVICES=0 python3 ../vllm/entrypoints/disaggregated/worker.py \ +CUDA_VISIBLE_DEVICES=0 python3 ../../vllm/entrypoints/disaggregated/worker.py \ --model $MODEL \ - --connector-addr $controller_addr \ + --controller-addr $CONTROLLER_ADDR \ --worker-addr $PREFILL_WORKER_ADDR \ --max-model-len 100 \ --gpu-memory-utilization 0.8 \ @@ -60,9 +60,9 @@ CUDA_VISIBLE_DEVICES=0 python3 ../vllm/entrypoints/disaggregated/worker.py \ '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' > vllm_disagg_prefill.log 2>&1 & # decoding instance, which is the KV consumer -CUDA_VISIBLE_DEVICES=1 python3 ../vllm/entrypoints/disaggregated/worker.py \ +CUDA_VISIBLE_DEVICES=1 python3 ../../vllm/entrypoints/disaggregated/worker.py \ --model $MODEL \ - --connector-addr $controller_addr \ + --controller-addr $CONTROLLER_ADDR \ --worker-addr $DECODE_WORKER_ADDR \ --max-model-len 100 \ --gpu-memory-utilization 0.8 \ @@ -73,10 +73,10 @@ CUDA_VISIBLE_DEVICES=1 python3 ../vllm/entrypoints/disaggregated/worker.py \ # the workflow of this proxy: # - Send req to prefill instance, wait until complete. # - Send req to decode instance, streaming tokens. 
-python3 ../vllm/entrypoints/disaggregated/connector.py \ +python3 ../../vllm/entrypoints/disaggregated/api_server.py \ --port $PORT \ --model $MODEL \ - --connector-addr $controller_addr \ + --controller-addr $CONTROLLER_ADDR \ --prefill-addr $PREFILL_WORKER_ADDR \ --decode-addr $DECODE_WORKER_ADDR diff --git a/vllm/disaggregated/pd_controller.py b/vllm/disaggregated/pd_controller.py new file mode 100644 index 0000000000000..ca6d8fe10781d --- /dev/null +++ b/vllm/disaggregated/pd_controller.py @@ -0,0 +1,368 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import os +from collections.abc import AsyncGenerator, Mapping +from typing import Optional, Union + +import msgspec +import zmq +import zmq.asyncio + +from vllm.config import DecodingConfig, ModelConfig +from vllm.core.scheduler import SchedulerOutputs +from vllm.disaggregated.protocol import (PDGenerationRequest, + PDGenerationResponse, PDRequestType, + PDResponseType) +from vllm.engine.protocol import EngineClient +from vllm.inputs.data import PromptType +from vllm.inputs.preprocess import InputPreprocessor +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import BeamSearchParams, SamplingParams +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.transformers_utils.tokenizer_group import TokenizerGroup +from vllm.utils import Device + +logger = init_logger(__name__) + +DEFAULT_MAX_TOKENS = 32000 + + +class PDController(EngineClient): + """ + Controller that schedules work on the PDWorkers. + + Conforms for the EngineClient protocol so it can + be wrapped with the OpenAI Server. + + Two Phases: + * Send request to prefill worker, await ack. + * Send request to decode worker, stream responses. + + KVSync happens directly between Engines, + handled by vLLM KVCacheTransfer. + + [ OpenAI Server ] + | + [ PDController ] + | + [ zmq ] + | + [ PDWorker ] [ PDWorker ] + | + [ Engine ] <---> [ Engine ] + + After PR #12957, we will support xPyD, so we will + also need to implement a scheduler and service + discovery for the workers. + + This PDController may be implemented as a K8s + controller. This is intended to be a prototype. + + * TODO: better error handling + * TODO: support logprobs, multimodal, etc. + """ + + def __init__(self, prefill_addr: str, decode_addr: str, + controller_addr: str, model_name: str): + # Request queues. + self.queues: dict[str, asyncio.Queue] = {} + + # Serialization encoder. + self.encoder = msgspec.msgpack.Encoder() + + # ZMQ communication. + # TODO: once https://github.com/vllm-project/vllm/pull/12957 + # lands, do service discovery to scale out workers. + self.ctx = zmq.asyncio.Context() + self.to_decode = self.ctx.socket(zmq.constants.PUSH) + self.to_decode.bind(f"{decode_addr}") + self.to_prefill = self.ctx.socket(zmq.constants.PUSH) + self.to_prefill.bind(f"{prefill_addr}") + self.controller_addr = controller_addr + self.decode_addr = decode_addr + self.prefill_addr = prefill_addr + + # Background loops (started on first generate()). + self.output_handler: Optional[asyncio.Task] = None + self.log_running: Optional[asyncio.Task] = None + + # Dummy: needed for EngineClient Protocol. + # TODO: refactor OAI Server to avoid needing this. 
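+        # The controller never loads a model itself; all generation happens in
+        # the PDWorkers reached over the PUSH sockets above, with results
+        # returned on a PULL socket bound to controller_addr in
+        # _run_output_handler(). A rough sketch of the socket pairing (the
+        # worker side lives in PDWorker):
+        #
+        #   controller PUSH.bind(prefill_addr)    -> prefill PULL.connect(...)
+        #   controller PUSH.bind(decode_addr)     -> decode  PULL.connect(...)
+        #   controller PULL.bind(controller_addr) <- workers PUSH.connect(...)
+        #
+        # The ModelConfig/TokenizerGroup below exist only so the OpenAI serving
+        # layer can introspect the model and tokenize prompts locally.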
+ self.model_config = ModelConfig(model=model_name, + tokenizer=model_name, + tokenizer_mode="auto", + trust_remote_code=False, + dtype="auto", + task="generate", + seed=42) + + # Dummy: needed for EngineClient Protocol. + # TODO: refactor OAI Server to avoid needing this. + self.tokenizer = TokenizerGroup( + **dict(tokenizer_id=self.model_config.tokenizer, + enable_lora=False, + max_num_seqs=1024, + max_loras=0, + max_input_length=None, + tokenizer_mode=self.model_config.tokenizer_mode, + trust_remote_code=self.model_config.trust_remote_code, + revision=self.model_config.tokenizer_revision, + truncation_side=self.model_config.truncation_side)) + + def shutdown(self): + if (ctx := self.ctx) is not None: + ctx.destroy(linger=0) + if (task := self.log_running) is not None: + task.cancel() + if (task := self.output_handler) is not None: + task.cancel() + + ipc_paths = [self.controller_addr, self.decode_addr, self.prefill_addr] + for path in ipc_paths: + socket_path = path.replace("ipc://", "") + if os.path.exists(socket_path): + os.remove(socket_path) + + async def _run_log_running(self): + logger.info("Running requests: %d", len(self.queues)) + await asyncio.sleep(10.) + + async def _run_output_handler(self): + """ + Pull responses from Decode + Prefill engines and + distribute back to the generate() tasks. + """ + decoder = msgspec.msgpack.Decoder(PDGenerationResponse) + + socket: Optional[zmq.asyncio.Socket] = None + try: + socket = self.ctx.socket(zmq.constants.PULL) + socket.bind(self.controller_addr) + + while True: + res_type, res_data = await socket.recv_multipart() + if res_type == PDResponseType.FAILURE: + raise Exception("Failure Response from PDWorker.") + elif res_type == PDResponseType.GENERATION: + response = decoder.decode(res_data) + logger.debug("Got Response: %s", response.request_id) + self.queues[response.request_id].put_nowait(response) + else: + raise Exception("Unknown response type.") + except Exception as e: + # TODO: distinguish between fatal and non-fatal errors. + for q in self.queues.values(): + q.put_nowait(e) + raise e + finally: + if socket is not None: + socket.close(linger=0) + + async def _run_prefill( + self, + request: PDGenerationRequest, + q: asyncio.Queue[Union[Exception, PDGenerationResponse]], + ): + # Send request to the prefill instance. + req_bytes = self.encoder.encode(request) + msg = (PDRequestType.GENERATION, req_bytes) + await self.to_prefill.send_multipart(msg, copy=False) + + # Await completion of the prefill. + response = await q.get() + if isinstance(response, Exception): + raise response + logger.debug("Got Decode Response: %s", request.request_id) + + async def _run_decode( + self, + request: PDGenerationRequest, + q: asyncio.Queue[Union[Exception, PDGenerationResponse]], + ) -> AsyncGenerator[PDGenerationResponse]: + # Send request to the decode instance. + req_bytes = self.encoder.encode(request) + msg = (PDRequestType.GENERATION, req_bytes) + await self.to_decode.send_multipart(msg, copy=False) + + # Iterate response queue and yield each response to caller. 
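+        # Responses for this request_id are routed into `q` by
+        # _run_output_handler() as the decode worker produces them; the stream
+        # ends with the first response whose finish_reason is set. Shape of the
+        # stream, with purely illustrative values:
+        #
+        #   PDGenerationResponse(request_id="req-0", text=..., token_ids=[...],
+        #                        finish_reason=None, stop_reason=None)
+        #   ...
+        #   PDGenerationResponse(request_id="req-0", text=..., token_ids=[...],
+        #                        finish_reason="length", stop_reason=None)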
+ finished = False + while not finished: + response = await q.get() + if isinstance(response, Exception): + raise response + logger.debug("Got Decode Response: %s", request.request_id) + finished = response.finish_reason is not None + yield response + + def _to_request_output( + self, + response: PDGenerationResponse, + prompt_token_ids: list[int], + ) -> RequestOutput: + finished = response.finish_reason is not None + return RequestOutput( + request_id=response.request_id, + prompt=None, + prompt_token_ids=prompt_token_ids, + prompt_logprobs=None, + outputs=[ + CompletionOutput(index=0, + text=response.text, + token_ids=response.token_ids, + cumulative_logprob=None, + logprobs=None, + finish_reason=response.finish_reason, + stop_reason=response.stop_reason) + ], + finished=finished, + ) + + async def generate( + self, + prompt: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> AsyncGenerator[RequestOutput]: + # Start loops on first request. + if self.output_handler is None: + self.output_handler = asyncio.create_task( + self._run_output_handler()) + self.log_running = asyncio.create_task(self._run_log_running()) + + # TODO: Expand to support the full matrix. + if "prompt_token_ids" not in prompt: + raise NotImplementedError( + "We currently only support TokensPrompt for P/D!") + if lora_request is not None: + raise NotImplementedError( + "We currently do not support LoRA for P/D!") + if trace_headers is not None: + raise NotImplementedError( + "We currently do not support tracing for P/D!") + if prompt_adapter_request is not None: + raise NotImplementedError( + "We currently do not support prompt adapter for P/D!") + if priority != 0: + raise NotImplementedError( + "We currently do not support priority for P/D!") + if request_id in self.queues: + raise ValueError(f"Found duplicate request_id: {request_id}!") + + # Queue to gather output from output_handler. + q = asyncio.Queue() + self.queues[request_id] = q + + # (1) Perform the Prefill. + original_max_tokens = sampling_params.max_tokens + request = PDGenerationRequest( + request_id=request_id, + prompt_token_ids=prompt["prompt_token_ids"], + sampling_params=sampling_params) + request.sampling_params.max_tokens = 1 + logger.debug("Sending Prefill: %s", request.request_id) + pd_response = await self._run_prefill(request, q) + + # (2) Perform the Decodes. 
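+        # Phase (1) sent the request with max_tokens=1, so the prefill worker
+        # generates a single token and finishes; its engine hands the KV cache
+        # to the decode engine through the configured kv_transfer_config
+        # connector (PyNcclConnector in the example script). That lone prefill
+        # response is awaited only as an ack and is never yielded to the caller.
+        # Now the original token budget is restored and the same request_id is
+        # replayed on the decode worker, which streams the real completion.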
+ logger.debug("Sending Decode: %s", request.request_id) + request.sampling_params.max_tokens = original_max_tokens + async for pd_response in self._run_decode(request, q): + yield self._to_request_output(pd_response, + prompt["prompt_token_ids"]) + + async def beam_search( + self, + prompt: PromptType, + request_id: str, + params: BeamSearchParams, + ) -> AsyncGenerator[RequestOutput, None]: + raise NotImplementedError + + def encode( + self, + prompt: PromptType, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, + ) -> AsyncGenerator[PoolingRequestOutput, None]: + raise NotImplementedError + + async def abort(self, request_id: str) -> None: + raise NotImplementedError + + async def get_model_config(self) -> ModelConfig: + return self.model_config + + async def get_decoding_config(self) -> DecodingConfig: + raise NotImplementedError + + async def get_input_preprocessor(self) -> InputPreprocessor: + raise NotImplementedError + + async def get_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + if lora_request is not None: + raise NotImplementedError( + "LoRA is not yet supported in the PDEngine.") + return self.tokenizer.get_lora_tokenizer(None) + + async def is_tracing_enabled(self) -> bool: + return False + + async def do_log_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[list[SamplerOutput]] = None, + ) -> None: + pass + + async def check_health(self) -> None: + pass + + async def start_profile(self) -> None: + raise NotImplementedError + + async def stop_profile(self) -> None: + raise NotImplementedError + + async def reset_prefix_cache(self, + device: Optional[Device] = None) -> None: + raise NotImplementedError + + async def sleep(self, level: int = 1) -> None: + raise NotImplementedError + + async def wake_up(self) -> None: + raise NotImplementedError + + async def is_sleeping(self) -> bool: + return False + + async def add_lora(self, lora_request: LoRARequest) -> None: + raise NotImplementedError + + @property + def errored(self) -> bool: + return False + + def dead_error(self) -> Exception: + return Exception("PDController has failed.") + + def is_running(self) -> bool: + return True + + def is_stopped(self) -> bool: + return False diff --git a/vllm/disaggregated/pd_worker.py b/vllm/disaggregated/pd_worker.py index 5053ae5fdb82a..99c3302b5b886 100644 --- a/vllm/disaggregated/pd_worker.py +++ b/vllm/disaggregated/pd_worker.py @@ -21,7 +21,7 @@ class PDWorker: self, engine: EngineClient, worker_addr: str, - client_addr: str, + controller_addr: str, ): """ PDWorker @@ -35,12 +35,12 @@ class PDWorker: # ZMQ IPC. 
self.worker_addr = worker_addr - self.client_addr = client_addr + self.controller_addr = controller_addr self.ctx = zmq.asyncio.Context() self.from_client = self.ctx.socket(zmq.constants.PULL) self.from_client.connect(f"ipc://{self.worker_addr}") self.to_client = self.ctx.socket(zmq.constants.PUSH) - self.to_client.connect(f"ipc://{self.client_addr}") + self.to_client.connect(f"ipc://{self.controller_addr}") self.decode_generation = msgspec.msgpack.Decoder(PDGenerationRequest) self.decode_abort = msgspec.msgpack.Decoder(PDAbortRequest) self.encoder = msgspec.msgpack.Encoder() @@ -56,8 +56,8 @@ class PDWorker: for running_request in self.running_requests: running_request.cancel() - if hasattr(self, "client_addr"): - ipc_paths = [self.worker_addr, self.client_addr] + if hasattr(self, "controller_addr"): + ipc_paths = [self.worker_addr, self.controller_addr] for ipc_path in ipc_paths: socket_path = ipc_path.replace("ipc://", "") if os.path.exists(socket_path): diff --git a/vllm/entrypoints/disaggregated/api_server.py b/vllm/entrypoints/disaggregated/api_server.py new file mode 100644 index 0000000000000..f2f925d6716fb --- /dev/null +++ b/vllm/entrypoints/disaggregated/api_server.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Toy connector for prototyping. + +When PDConroller supports the protocol and we clean up the +OpenAI Server, we can drop this in favor of vllm serve. +""" + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager + +import uvicorn +import uvloop +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, StreamingResponse + +from vllm.disaggregated.pd_controller import PDController +from vllm.entrypoints.openai.protocol import (CompletionRequest, + CompletionResponse, + ErrorResponse) +from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion +from vllm.entrypoints.openai.serving_models import (BaseModelPath, + OpenAIServingModels) +from vllm.logger import init_logger +from vllm.utils import FlexibleArgumentParser, set_ulimit + +# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) +logger = init_logger('vllm.entrypoints.disaggregated.api_server') + +app = FastAPI() + + +@app.get("/v1/models") +async def show_available_models(raw_request: Request): + handler: OpenAIServingModels = raw_request.app.state.openai_serving_models + models_ = await handler.show_available_models() + return JSONResponse(content=models_.model_dump()) + + +@app.post("/v1/completions") +async def create_completion(request: CompletionRequest, raw_request: Request): + handler: OpenAIServingCompletion = raw_request.app.state.openai_serving_completion # noqa: E501 + generator = await handler.create_completion(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, CompletionResponse): + return JSONResponse(content=generator.model_dump()) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + +@asynccontextmanager +async def controller_ctx(prefill_addr: str, decode_addr: str, + controller_addr: str, + model_name: str) -> AsyncIterator[PDController]: + c = PDController(prefill_addr, decode_addr, controller_addr, model_name) + yield c + c.shutdown() + + +async def main(args, **uvicorn_kwargs): + logger.info("vLLM Disaggregated Connector Start %s %s", args, + uvicorn_kwargs) + + # Avoid dropping requests under high concurrency. 
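+    # set_ulimit() raises the soft open-file limit so many concurrent HTTP
+    # connections plus the zmq ipc sockets do not exhaust file descriptors.
+    #
+    # Once this server and both workers are up, the proxy serves the standard
+    # completions API. An illustrative client call (model name and port depend
+    # on the launch flags; assumes the serving layer tokenizes text prompts
+    # before they reach the controller):
+    #
+    #   import requests
+    #   resp = requests.post(
+    #       "http://localhost:8000/v1/completions",
+    #       json={"model": "<served-model-name>",
+    #             "prompt": "San Francisco is a",
+    #             "max_tokens": 16})
+    #   print(resp.json()["choices"][0]["text"])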
+ set_ulimit() + + # IPC Paths. + prefill_addr = f"ipc://{args.prefill_addr}" + decode_addr = f"ipc://{args.decode_addr}" + controller_addr = f"ipc://{args.controller_addr}" + + # Start Engine. + async with controller_ctx(prefill_addr=prefill_addr, + decode_addr=decode_addr, + controller_addr=controller_addr, + model_name=args.model) as engine_client: + + # Initialize App State. + model_config = await engine_client.get_model_config() + app.state.openai_serving_models = OpenAIServingModels( + engine_client=engine_client, + model_config=model_config, + base_model_paths=[ + BaseModelPath(name=args.served_model_name or args.model, + model_path=args.model) + ], + ) + app.state.openai_serving_completion = OpenAIServingCompletion( + engine_client=engine_client, + model_config=model_config, + models=app.state.openai_serving_models, + request_logger=None, + ) + + # Run Server. + config = uvicorn.Config(app, host=args.host, port=args.port) + server = uvicorn.Server(config) + await server.serve() + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description="vLLM OpenAI-Compatible P/D Server.") + parser.add_argument("--host", + type=str, + default="0.0.0.0", + help="The host of the HTTP server.") + parser.add_argument("--port", + type=int, + default=8000, + help="The port of the HTTP server.") + parser.add_argument("--model", + type=str, + required=True, + help="The path to the model.") + parser.add_argument("--served-model-name", + type=str, + default=None, + help="The served name of the model.") + parser.add_argument("--controller-addr", + type=str, + required=True, + help="The zmq ipc controller address") + parser.add_argument("--prefill-addr", + type=str, + required=True, + help="The zmq ipc prefill address") + parser.add_argument("--decode-addr", + type=str, + required=True, + help="The zmq ipc decode address") + args = parser.parse_args() + uvloop.run(main(args)) diff --git a/vllm/entrypoints/disaggregated/worker.py b/vllm/entrypoints/disaggregated/worker.py index 348d46db75d8c..041101d62dfab 100644 --- a/vllm/entrypoints/disaggregated/worker.py +++ b/vllm/entrypoints/disaggregated/worker.py @@ -15,7 +15,7 @@ logger = init_logger(__name__) async def run(args, engine: EngineClient): try: - worker = PDWorker(engine, args.worker_addr, args.client_addr) + worker = PDWorker(engine, args.worker_addr, args.controller_addr) await worker.run_busy_loop() finally: worker.shutdown() @@ -32,10 +32,10 @@ async def main(args) -> None: if __name__ == "__main__": parser = FlexibleArgumentParser() - parser.add_argument('--client-addr', + parser.add_argument('--controller-addr', type=str, required=True, - help='The address of the connector.') + help='The address of the controller.') parser.add_argument('--worker-addr', type=str, required=True,