From 284d5df45b7f8b9aa0b541481e0191c7f387ddda Mon Sep 17 00:00:00 2001
From: "rshaw@neuralmagic.com" <rshaw@neuralmagic.com>
Date: Sun, 23 Mar 2025 22:50:20 +0000
Subject: [PATCH] added __init__.py

Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
---
 .../disaggregated_prefill_zmq.sh            |   8 +-
 vllm/disaggregated/pd_client.py             | 359 ------------------
 vllm/disaggregated/pd_worker.py             |  13 +-
 vllm/disaggregated/protocol.py              |   8 +-
 vllm/entrypoints/disaggregated/connector.py |  22 +-
 5 files changed, 26 insertions(+), 384 deletions(-)
 delete mode 100644 vllm/disaggregated/pd_client.py

diff --git a/examples/online_serving/disaggregated_prefill_zmq.sh b/examples/online_serving/disaggregated_prefill_zmq.sh
index 382d6d2269088..1982fce556884 100644
--- a/examples/online_serving/disaggregated_prefill_zmq.sh
+++ b/examples/online_serving/disaggregated_prefill_zmq.sh
@@ -45,14 +45,14 @@ wait_for_disagg_server() {
 # You can also adjust --kv-ip and --kv-port for distributed inference.
 
 MODEL=meta-llama/Llama-3.1-8B-Instruct
-CONNECTOR_ADDR=connectoripc
+CONTROLLER_ADDR=connectoripc
 PREFILL_WORKER_ADDR=prefillipc
 DECODE_WORKER_ADDR=decodeipc
 
 # prefilling instance, which is the KV producer
 CUDA_VISIBLE_DEVICES=0 python3 ../vllm/entrypoints/disaggregated/worker.py \
     --model $MODEL \
-    --connector-addr $CONNECTOR_ADDR \
+    --connector-addr $CONTROLLER_ADDR \
     --worker-addr $PREFILL_WORKER_ADDR \
     --max-model-len 100 \
     --gpu-memory-utilization 0.8 \
@@ -62,7 +62,7 @@ CUDA_VISIBLE_DEVICES=0 python3 ../vllm/entrypoints/disaggregated/worker.py \
 # decoding instance, which is the KV consumer
 CUDA_VISIBLE_DEVICES=1 python3 ../vllm/entrypoints/disaggregated/worker.py \
     --model $MODEL \
-    --connector-addr $CONNECTOR_ADDR \
+    --connector-addr $CONTROLLER_ADDR \
     --worker-addr $DECODE_WORKER_ADDR \
     --max-model-len 100 \
     --gpu-memory-utilization 0.8 \
@@ -76,7 +76,7 @@ CUDA_VISIBLE_DEVICES=1 python3 ../vllm/entrypoints/disaggregated/worker.py \
 python3 ../vllm/entrypoints/disaggregated/connector.py \
     --port $PORT \
     --model $MODEL \
-    --connector-addr $CONNECTOR_ADDR \
+    --connector-addr $CONTROLLER_ADDR \
     --prefill-addr $PREFILL_WORKER_ADDR \
     --decode-addr $DECODE_WORKER_ADDR

diff --git a/vllm/disaggregated/pd_client.py b/vllm/disaggregated/pd_client.py
deleted file mode 100644
index 15f41e84e8860..0000000000000
--- a/vllm/disaggregated/pd_client.py
+++ /dev/null
@@ -1,359 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-
-import asyncio
-import os
-from collections.abc import AsyncGenerator, Mapping
-from typing import Optional, Union
-
-import msgspec
-import zmq
-import zmq.asyncio
-
-from vllm.config import DecodingConfig, ModelConfig
-from vllm.core.scheduler import SchedulerOutputs
-from vllm.disaggregated.protocol import (PDGenerationRequest,
-                                         PDGenerationResponse, PDRequestType,
-                                         PDResponseType)
-from vllm.engine.protocol import EngineClient
-from vllm.inputs.data import PromptType
-from vllm.inputs.preprocess import InputPreprocessor
-from vllm.logger import init_logger
-from vllm.lora.request import LoRARequest
-from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
-from vllm.pooling_params import PoolingParams
-from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import BeamSearchParams, SamplingParams
-from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.transformers_utils.tokenizer_group import TokenizerGroup
-from vllm.utils import Device
-
-logger = init_logger(__name__)
-
-DEFAULT_MAX_TOKENS = 32000
-
-
-class 
PDController(EngineClient): - """ - Controller that schedules work on the PDWorkers. - - Conforms for the EngineClient protocol so it can - be wrapped with the OpenAI Server. - - Two Phases: - * Send request to prefill worker, await ack. - * Send request to decode worker. - - KVSync happens directly between Engines, - handled by vLLM KVCacheTransfer. - - [ OpenAI Server ] - | - [ PDController ] - | - [ zmq ] - | - [ PDWorker ] [ PDWorker ] - | - [ Engine ] <---> [ Engine ] - - After PR #12957, we will support xPyD, so we will - also need to implement a scheduler and service - discovery for the workers. - - This PDController may be implemented as a K8s - controller. This is intended to be a prototype. - - * TODO: better error handling - * TODO: support logprobs, multimodal, etc. - """ - - def __init__(self, prefill_addr: str, decode_addr: str, - connector_addr: str, model_name: str): - # Request queues. - self.queues: dict[str, asyncio.Queue] = {} - - # Serialization encoder. - self.encoder = msgspec.msgpack.Encoder() - - # ZMQ communication. - # TODO: once https://github.com/vllm-project/vllm/pull/12957 - # lands, do service discovery to scale out workers. - self.ctx = zmq.asyncio.Context() - self.to_decode = self.ctx.socket(zmq.constants.PUSH) - self.to_decode.bind(f"{decode_addr}") - self.to_prefill = self.ctx.socket(zmq.constants.PUSH) - self.to_prefill.bind(f"{prefill_addr}") - self.connector_addr = connector_addr - self.decode_addr = decode_addr - self.prefill_addr = prefill_addr - - # Background loops (started on first generate()). - self.output_handler: Optional[asyncio.Task] = None - self.log_running: Optional[asyncio.Task] = None - - # Dummy: needed for EngineClient Protocol. - # TODO: refactor OAI Server to avoid needing this. - self.model_config = ModelConfig(model=model_name, - tokenizer=model_name, - tokenizer_mode="auto", - trust_remote_code=False, - dtype="auto", - task="generate", - seed=42) - - # Dummy: needed for EngineClient Protocol. - # TODO: refactor OAI Server to avoid needing this. - self.tokenizer = TokenizerGroup( - **dict(tokenizer_id=self.model_config.tokenizer, - enable_lora=False, - max_num_seqs=1024, - max_loras=0, - max_input_length=None, - tokenizer_mode=self.model_config.tokenizer_mode, - trust_remote_code=self.model_config.trust_remote_code, - revision=self.model_config.tokenizer_revision, - truncation_side=self.model_config.truncation_side)) - - def shutdown(self): - if (ctx := self.ctx) is not None: - ctx.destroy(linger=0) - if (task := self.log_running) is not None: - task.cancel() - if (task := self.output_handler) is not None: - task.cancel() - - ipc_paths = [self.connector_addr, self.decode_addr, self.prefill_addr] - for path in ipc_paths: - socket_path = path.replace("ipc://", "") - if os.path.exists(socket_path): - os.remove(socket_path) - - async def _run_log_running(self): - logger.info("Running requests: %d", len(self.queues)) - await asyncio.sleep(10.) - - async def _run_output_handler(self): - """ - Pull responses from Decode + Prefill engines and - distribute back to the generate() tasks. 
- """ - decoder = msgspec.msgpack.Decoder(PDGenerationResponse) - - socket: Optional[zmq.asyncio.Socket] = None - try: - socket = self.ctx.socket(zmq.constants.PULL) - socket.bind(self.connector_addr) - - while True: - res_type, res_data = await socket.recv_multipart() - if res_type == PDResponseType.FAILURE: - raise Exception("Failure Response from PDWorker.") - elif res_type == PDResponseType.GENERATION: - response = decoder.decode(res_data) - logger.debug("Got Response: %s", response.request_id) - self.queues[response.request_id].put_nowait(response) - else: - raise Exception("Unknown response type.") - except Exception as e: - # TODO: distinguish between fatal and non-fatal errors. - for q in self.queues.values(): - q.put_nowait(e) - raise e - finally: - if socket is not None: - socket.close(linger=0) - - async def _run_prefill( - self, - request: PDGenerationRequest, - q: asyncio.Queue[Union[Exception, PDGenerationResponse]], - ): - # Send request to the prefill instance. - req_bytes = self.encoder.encode(request) - msg = (PDRequestType.GENERATION, req_bytes) - await self.to_prefill.send_multipart(msg, copy=False) - - # Await completion of the prefill. - response = await q.get() - if isinstance(response, Exception): - raise response - logger.debug("Got Decode Response: %s", request.request_id) - - async def _run_decode( - self, - request: PDGenerationRequest, - q: asyncio.Queue[Union[Exception, PDGenerationResponse]], - ) -> AsyncGenerator[PDGenerationResponse]: - # Send request to the decode instance. - req_bytes = self.encoder.encode(request) - msg = (PDRequestType.GENERATION, req_bytes) - await self.to_decode.send_multipart(msg, copy=False) - - # Iterate response queue and yield each response to caller. - finished = False - while not finished: - response = await q.get() - if isinstance(response, Exception): - raise response - logger.debug("Got Decode Response: %s", request.request_id) - finished = response.finish_reason is not None - yield response - - def _to_request_output( - self, - response: PDGenerationResponse, - prompt_token_ids: list[int], - ) -> RequestOutput: - finished = response.finish_reason is not None - return RequestOutput( - request_id=response.request_id, - prompt=None, - prompt_token_ids=prompt_token_ids, - prompt_logprobs=None, - outputs=[ - CompletionOutput(index=0, - text=response.text, - token_ids=response.token_ids, - cumulative_logprob=None, - logprobs=None, - finish_reason=response.finish_reason, - stop_reason=response.stop_reason) - ], - finished=finished, - ) - - async def generate( - self, - prompt: PromptType, - sampling_params: SamplingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, - ) -> AsyncGenerator[RequestOutput]: - # Start loops on first request. - if self.output_handler is None: - self.output_handler = asyncio.create_task( - self._run_output_handler()) - self.log_running = asyncio.create_task(self._run_log_running()) - - # TODO: Expand to support the full matrix. 
- if "prompt_token_ids" not in prompt: - raise NotImplementedError( - "We currently only support TokensPrompt for P/D!") - if lora_request is not None: - raise NotImplementedError( - "We currently do not support LoRA for P/D!") - if trace_headers is not None: - raise NotImplementedError( - "We currently do not support tracing for P/D!") - if prompt_adapter_request is not None: - raise NotImplementedError( - "We currently do not support prompt adapter for P/D!") - if priority != 0: - raise NotImplementedError( - "We currently do not support priority for P/D!") - if request_id in self.queues: - raise ValueError(f"Found duplicate request_id: {request_id}!") - - # Queue to gather output from output_handler. - q = asyncio.Queue() - self.queues[request_id] = q - - # (1) Perform the Prefill. - original_max_tokens = sampling_params.max_tokens - request = PDGenerationRequest( - request_id=request_id, - prompt_token_ids=prompt["prompt_token_ids"], - sampling_params=sampling_params) - request.sampling_params.max_tokens = 1 - logger.debug("Sending Prefill: %s", request.request_id) - pd_response = await self._run_prefill(request, q) - - # (2) Perform the Decodes. - logger.debug("Sending Decode: %s", request.request_id) - request.sampling_params.max_tokens = original_max_tokens - async for pd_response in self._run_decode(request, q): - yield self._to_request_output(pd_response, - prompt["prompt_token_ids"]) - - async def beam_search( - self, - prompt: PromptType, - request_id: str, - params: BeamSearchParams, - ) -> AsyncGenerator[RequestOutput, None]: - raise NotImplementedError - - def encode( - self, - prompt: PromptType, - pooling_params: PoolingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, - ) -> AsyncGenerator[PoolingRequestOutput, None]: - raise NotImplementedError - - async def abort(self, request_id: str) -> None: - raise NotImplementedError - - async def get_model_config(self) -> ModelConfig: - return self.model_config - - async def get_decoding_config(self) -> DecodingConfig: - raise NotImplementedError - - async def get_input_preprocessor(self) -> InputPreprocessor: - raise NotImplementedError - - async def get_tokenizer( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - if lora_request is not None: - raise NotImplementedError( - "LoRA is not yet supported in the PDEngine.") - return self.tokenizer.get_lora_tokenizer(None) - - async def is_tracing_enabled(self) -> bool: - return False - - async def do_log_stats( - self, - scheduler_outputs: Optional[SchedulerOutputs] = None, - model_output: Optional[list[SamplerOutput]] = None, - ) -> None: - pass - - async def check_health(self) -> None: - pass - - async def start_profile(self) -> None: - raise NotImplementedError - - async def stop_profile(self) -> None: - raise NotImplementedError - - async def reset_prefix_cache(self, - device: Optional[Device] = None) -> None: - raise NotImplementedError - - async def sleep(self, level: int = 1) -> None: - raise NotImplementedError - - async def wake_up(self) -> None: - raise NotImplementedError - - async def is_sleeping(self) -> bool: - return False - - async def add_lora(self, lora_request: LoRARequest) -> None: - raise NotImplementedError - - @property - def errored(self) -> bool: - return False diff --git a/vllm/disaggregated/pd_worker.py b/vllm/disaggregated/pd_worker.py index e24240d83f3df..5053ae5fdb82a 100644 --- a/vllm/disaggregated/pd_worker.py +++ 
b/vllm/disaggregated/pd_worker.py
@@ -26,8 +26,8 @@ class PDWorker:
     """
     PDWorker
     * Wrapper around AsyncLLM to handle converting PRRequests
-      to PDResponse and sending back to the PDClient.
-    * Leverages ZMQ for communication with PDClient. We may
+      to PDResponse and sending back to the PDController.
+    * Leverages ZMQ for communication with PDController. We may
       expand this in the future.
     """
     # Engine.
@@ -66,7 +66,7 @@ class PDWorker:
     async def run_busy_loop(self):
         """
         main loop:
-          1) wait for a request from the PDClient
+          1) wait for a request from the PDController
           2) handle the request
         """
         logger.info("PDWorker is ready To handle requests.")
@@ -116,7 +116,7 @@ class PDWorker:
         * 1) submit request to AsyncLLM
         * 2) iterate the RequestOutputs
         * 3) convert RequestOutput --> PDResponse
-        * 4) serialize and send to PDClient
+        * 4) serialize and send to PDController
         """
 
         request_id = req.request_id
@@ -131,8 +131,9 @@ class PDWorker:
         # 3) Convert RequestOutput --> PDResponse.
         response = PDGenerationResponse.from_request_output(request_output)
 
-        # 4) Serialize and send to PDClient.
+        # 4) Serialize and send to PDController.
         response_bytes = self.encoder.encode(response)
+        msg = [PDResponseType.GENERATION, response_bytes]
         logger.debug("Sending: %s", request_id)
-        await self.to_client.send(response_bytes, copy=False)
+        await self.to_client.send_multipart(msg, copy=False)
         logger.debug("Sent: %s", request_id)
diff --git a/vllm/disaggregated/protocol.py b/vllm/disaggregated/protocol.py
index ee7dde519acf3..f82fa6d653387 100644
--- a/vllm/disaggregated/protocol.py
+++ b/vllm/disaggregated/protocol.py
@@ -14,8 +14,8 @@ from vllm.outputs import RequestOutput
 
 
 class PDRequestType:
-    GENERATION = b"generation"
-    ABORT = b"abort"
+    GENERATION = b'\x00'
+    ABORT = b'\x01'
 
 
 class PDGenerationRequest(msgspec.Struct):
@@ -30,8 +30,8 @@ class PDAbortRequest(msgspec.Struct):
 
 
 class PDResponseType:
-    GENERATION = b"generation"
-    FAILURE = b"failure"
+    GENERATION = b'\x00'
+    FAILURE = b'\x01'
 
 
 class PDGenerationResponse(msgspec.Struct):
diff --git a/vllm/entrypoints/disaggregated/connector.py b/vllm/entrypoints/disaggregated/connector.py
index c098c433e08cf..42b03434c86a0 100644
--- a/vllm/entrypoints/disaggregated/connector.py
+++ b/vllm/entrypoints/disaggregated/connector.py
@@ -2,7 +2,7 @@
 """
 Toy connector for prototyping.
 
-When PDClient supports the protocol and we clean up the
+When PDController supports the protocol and we clean up the
 OpenAI Server, we can drop this in favor of vllm serve.
""" @@ -14,7 +14,7 @@ import uvloop from fastapi import FastAPI, Request from fastapi.responses import JSONResponse, StreamingResponse -from vllm.disaggregated.pd_client import PDClient +from vllm.disaggregated.pd_contoller import PDController from vllm.entrypoints.openai.protocol import (CompletionRequest, CompletionResponse, ErrorResponse) @@ -51,12 +51,12 @@ async def create_completion(request: CompletionRequest, raw_request: Request): @asynccontextmanager -async def pd_client_ctx(prefill_addr: str, decode_addr: str, - connector_addr: str, - model_name: str) -> AsyncIterator[PDClient]: - client = PDClient(prefill_addr, decode_addr, connector_addr, model_name) - yield client - client.shutdown() +async def contoller_ctx(prefill_addr: str, decode_addr: str, + contoller_addr: str, + model_name: str) -> AsyncIterator[PDController]: + c = PDController(prefill_addr, decode_addr, contoller_addr, model_name) + yield c + c.shutdown() async def main(args, **uvicorn_kwargs): @@ -69,12 +69,12 @@ async def main(args, **uvicorn_kwargs): # IPC Paths. prefill_addr = f"ipc://{args.prefill_addr}" decode_addr = f"ipc://{args.decode_addr}" - connector_addr = f"ipc://{args.connector_addr}" + contoller_addr = f"ipc://{args.contoller_addr}" # Start Engine. - async with pd_client_ctx(prefill_addr=prefill_addr, + async with contoller_ctx(prefill_addr=prefill_addr, decode_addr=decode_addr, - connector_addr=connector_addr, + contoller_addr=contoller_addr, model_name=args.model) as engine_client: # Initialize App State.