From 522279ebb9a4b209eebc92a67cc21c5d58535432 Mon Sep 17 00:00:00 2001
From: Robert Shaw
Date: Sat, 22 Mar 2025 17:12:21 -0400
Subject: [PATCH] Stash

Signed-off-by: Robert Shaw
---
 .../disaggregated_prefill_zmq.sh              |  34 +-
 vllm/entrypoints/disaggregated/api_server.py  | 126 --------
 vllm/entrypoints/disaggregated/pd_engine.py   | 296 ------------------
 .../disaggregated/worker_server.py            | 126 --------
 4 files changed, 22 insertions(+), 560 deletions(-)
 delete mode 100644 vllm/entrypoints/disaggregated/api_server.py
 delete mode 100644 vllm/entrypoints/disaggregated/pd_engine.py
 delete mode 100644 vllm/entrypoints/disaggregated/worker_server.py

diff --git a/examples/online_serving/disaggregated_prefill_zmq.sh b/examples/online_serving/disaggregated_prefill_zmq.sh
index 81301e8b53386..1a5532e85592a 100644
--- a/examples/online_serving/disaggregated_prefill_zmq.sh
+++ b/examples/online_serving/disaggregated_prefill_zmq.sh
@@ -44,18 +44,27 @@ wait_for_disagg_server() {

 # You can also adjust --kv-ip and --kv-port for distributed inference.

+MODEL=meta-llama/Llama-3.1-8B-Instruct
+CONNECTOR_ADDR=connectoripc
+PREFILL_WORKER_ADDR=prefillipc
+DECODE_WORKER_ADDR=decodeipc
+PORT=8000
 # prefilling instance, which is the KV producer
-CUDA_VISIBLE_DEVICES=0 vllm disagg meta-llama/Meta-Llama-3.1-8B-Instruct \
-    --zmq-server-addr testipc0 \
+CUDA_VISIBLE_DEVICES=0 python3 -m vllm.entrypoints.disaggregated.worker \
+    --model $MODEL \
+    --connector-addr $CONNECTOR_ADDR \
+    --worker-addr $PREFILL_WORKER_ADDR \
     --max-model-len 100 \
     --gpu-memory-utilization 0.8 \
     --kv-transfer-config \
     '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' > vllm_disagg_prefill.log 2>&1 &

 # decoding instance, which is the KV consumer
-CUDA_VISIBLE_DEVICES=1 vllm disagg meta-llama/Meta-Llama-3.1-8B-Instruct \
-    --zmq-server-addr testipc1 \
+CUDA_VISIBLE_DEVICES=1 python3 -m vllm.entrypoints.disaggregated.worker \
+    --model $MODEL \
+    --connector-addr $CONNECTOR_ADDR \
+    --worker-addr $DECODE_WORKER_ADDR \
     --max-model-len 100 \
     --gpu-memory-utilization 0.8 \
     --kv-transfer-config \
@@ -63,16 +72,17 @@ CUDA_VISIBLE_DEVICES=1 vllm disagg meta-llama/Meta-Llama-3.1-8B-Instruct \

 # launch a proxy server that opens the service at port 8000
 # the workflow of this proxy:
-# - send the request to prefill vLLM instance (via zmq addr testipc0), change max_tokens
-#   to 1
-# - after the prefill vLLM finishes prefill, send the request to decode vLLM
-#   instance (via zmq addr testipc1)
-vllm connect --port 8000 \
-    --prefill-addr testipc0 \
-    --decode-addr testipc1 &
+# - Send the request to the prefill instance and wait for it to complete.
+# - Send the request to the decode instance, streaming tokens back.
+python3 -m vllm.entrypoints.disaggregated.connector \
+    --port $PORT \
+    --model $MODEL \
+    --connector-addr $CONNECTOR_ADDR \
+    --prefill-addr $PREFILL_WORKER_ADDR \
+    --decode-addr $DECODE_WORKER_ADDR &

 # wait until prefill, decode instances and proxy are ready
-wait_for_server 8000
+wait_for_server $PORT
 wait_for_disagg_server vllm_disagg_prefill.log
 wait_for_disagg_server vllm_disagg_decode.log
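Once the three wait checks above pass, the connector speaks the OpenAI completions protocol on $PORT. A minimal smoke test, sketched in Python against the defaults set in the script (port 8000 and the Llama model name; any HTTP client would do, urllib is used here only to avoid extra dependencies):

```python
# Hypothetical smoke test for the P/D setup launched by the script above.
import json
import urllib.request

payload = {
    "model": "meta-llama/Llama-3.1-8B-Instruct",
    "prompt": "Disaggregated prefill separates",
    "max_tokens": 32,  # keeps prompt + output under --max-model-len 100
}
req = urllib.request.Request(
    "http://localhost:8000/v1/completions",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp)["choices"][0]["text"])
```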
diff --git a/vllm/entrypoints/disaggregated/api_server.py b/vllm/entrypoints/disaggregated/api_server.py
deleted file mode 100644
index 2b18e594163f3..0000000000000
--- a/vllm/entrypoints/disaggregated/api_server.py
+++ /dev/null
@@ -1,126 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-
-import uvicorn
-import uvloop
-
-from collections.abc import AsyncIterator
-from contextlib import asynccontextmanager
-from fastapi import FastAPI, Request
-from fastapi.responses import JSONResponse, StreamingResponse
-
-from vllm.entrypoints.disaggregated.pd_engine import PDEngine
-from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
-from vllm.entrypoints.openai.serving_models import (BaseModelPath,
-                                                    OpenAIServingModels)
-from vllm.entrypoints.openai.protocol import CompletionRequest
-from vllm.logger import init_logger
-from vllm.utils import FlexibleArgumentParser, set_ulimit, make_zmq_socket
-from vllm.entrypoints.openai.protocol import (
-    CompletionResponse, ErrorResponse)
-
-# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
-logger = init_logger('vllm.entrypoints.disaggregated.api_server')
-
-app = FastAPI()
-
-@app.get("/v1/models")
-async def show_available_models(raw_request: Request):
-    handler: OpenAIServingModels = raw_request.app.state.openai_serving_models
-    models_ = await handler.show_available_models()
-    return JSONResponse(content=models_.model_dump())
-
-@app.post("/v1/completions")
-async def create_completion(request: CompletionRequest, raw_request: Request):
-    handler: OpenAIServingCompletion = raw_request.app.state.openai_serving_completion
-    generator = await handler.create_completion(request, raw_request)
-    if isinstance(generator, ErrorResponse):
-        return JSONResponse(content=generator.model_dump(),
-                            status_code=generator.code)
-    elif isinstance(generator, CompletionResponse):
-        return JSONResponse(content=generator.model_dump())
-
-    return StreamingResponse(content=generator, media_type="text/event-stream")
-
-@asynccontextmanager
-async def pd_engine_client_ctx_manager(
-    model_name: str,
-    prefill_addr: str,
-    decode_addr: str,
-    connector_addr: str) -> AsyncIterator[PDEngine]:
-    engine = PDEngine(model_name, prefill_addr, decode_addr, connector_addr)
-    yield engine
-    engine.shutdown()
-
-async def main(args, **uvicorn_kwargs):
-    logger.info("vLLM Disaggregate Connector Start %s %s", args,
-                uvicorn_kwargs)
-
-    # Avoid dropping requests under high concurrency.
-    set_ulimit()
-
-    # IPC Paths.
-    # NOTE FOR DEVELOPERS: when shifting to TCP, ensure you
-    # are not using pickle to avoid RCE security flaw.
-    prefill_addr = f"ipc://{args.prefill_addr}"
-    decode_addr = f"ipc://{args.decode_addr}"
-    connector_addr = f"ipc://{args.connector_addr}"
-
-    # Start Engine.
-    with pd_engine_client_ctx_manager(
-        args.model, prefill_addr, decode_addr, connector_addr) as engine_client:
-
-        # Initialize App State.
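-        # The wiring below backs the two routes declared at the top of
-        # this file: openai_serving_models answers /v1/models and
-        # openai_serving_completion answers /v1/completions, with the
-        # PDEngine standing in for the usual AsyncLLM engine client.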
- model_config = await engine_client.get_model_config() - app.state.openai_serving_models = OpenAIServingModels( - engine_client=engine_client, - model_config=model_config, - base_model_paths=[BaseModelPath( - name=args.served_model_name or args.model, - model_path=args.model) - ], - ) - app.state.openai_serving_completion = OpenAIServingCompletion( - engine_client=engine_client, - model_config=model_config, - models=app.state.openai_serving_models, - request_logger=None, - ) - - # Run Server. - config = uvicorn.Config(app, host="0.0.0.0", port=args.port) - server = uvicorn.Server(config) - await server.serve() - -if __name__ == "__main__": - parser = FlexibleArgumentParser( - description="vLLM OpenAI-Compatible P/D Server.") - parser.add_argument("--host", - type=str, - default="0.0.0.0", - help="The host of the HTTP server.") - parser.add_argument("--port", - type=int, - default=8001, - help="The port of the HTTP server.") - parser.add_argument("--model", - type=str, - required=True, - help="The path to the model.") - parser.add_argument("--served-model-name", - type=str, - default=None, - help="The served name of the model.") - parser.add_argument("--connector-addr", - type=str, - required=True, - help="The zmq ipc connector address") - parser.add_argument("--prefill-addr", - type=str, - required=True, - help="The zmq ipc prefill address") - parser.add_argument("--decode-addr", - type=str, - required=True, - help="The zmq ipc decode address") - args = parser.parse_args() - uvloop.run(main(args)) diff --git a/vllm/entrypoints/disaggregated/pd_engine.py b/vllm/entrypoints/disaggregated/pd_engine.py deleted file mode 100644 index fc2d7177bf99b..0000000000000 --- a/vllm/entrypoints/disaggregated/pd_engine.py +++ /dev/null @@ -1,296 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import asyncio -import msgspec -import os -from collections.abc import AsyncGenerator -from typing import Dict, List, Mapping, Optional - -import uvloop -import zmq -import zmq.asyncio - -from vllm import SamplingParams -from vllm.config import DecodingConfig, ModelConfig -from vllm.core.scheduler import SchedulerOutputs -from vllm.entrypoints.disaggregated.types import PDRequest, PDResponse -from vllm.inputs.data import PromptType -from vllm.inputs.preprocess import InputPreprocessor -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.outputs import PoolingRequestOutput, RequestOutput -from vllm.pooling_params import PoolingParams -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import BeamSearchParams, SamplingParams -from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.transformers_utils.tokenizer_group import TokenizerGroup -from vllm.utils import Device, make_zmq_socket - -logger = init_logger(__name__) - -DEFAULT_MAX_TOKENS = 32000 - -class PDEngine: - """ - PDEngine: - Equiavlent of AsyncLLM for P/D. Assumes there is - a Prefill and Decode service already running. - - * TODO: actually handle errors and failure. - * TODO: support more than just text input. - * TODO: move under vllm/v1/engine one past prototype. - """ - - def __init__( - self, - prefill_addr: str, - decode_addr: str, - connector_addr: str, - model_name: str - ): - # Request queues. - self.queues: Dict[str, asyncio.Queue] = {} - - # Serialization encoder. - self.encoder = msgspec.msgpack.Encoder() - - # ZMQ communication.. 
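-        # Socket topology, as wired below and in _run_output_handler:
-        #
-        #   connector --PUSH--> prefill worker   (self.to_prefill)
-        #   connector --PUSH--> decode worker    (self.to_decode)
-        #   workers   --PUSH--> connector_addr --PULL--> output handler
-        #
-        # Requests travel one way over PUSH/PULL pipes; responses come
-        # back on a separate PULL socket rather than as replies on the
-        # request socket.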
- self.ctx = zmq.asyncio.Context() - self.to_decode = make_zmq_socket( - self.ctx, f"{decode_addr}", zmq.constants.PUSH) - self.to_prefill = make_zmq_socket( - self.ctx, f"{prefill_addr}", zmq.constants.PUSH) - self.connector_addr = connector_addr - self.decode_addr = decode_addr - self.prefill_addr = prefill_addr - - # Background loops (started on first generate()). - self.output_handler: Optional[asyncio.Task] = None - self.log_running: Optional[asyncio.Task] = None - - # Dummy: needed for EngineClient Protocol. - # TODO: refactor EngineClient to avoid needing this. - self.model_config = ModelConfig( - model=model_name, - tokenizer=model_name, - tokenizer_mode="auto", - trust_remote_code=False, - dtype="auto", - seed=42 - ) - - # Dummy: needed for EngineClient Protocol. - # TODO: refactor EngineClient to avoid needing this. - init_kwargs = dict( - tokenizer_id=self.model_config.tokenizer, - enable_lora=False, - max_num_seqs=1024, - max_loras=0, - max_input_length=None, - tokenizer_mode=self.model_config.tokenizer_mode, - trust_remote_code=self.model_config.trust_remote_code, - revision=self.model_config.tokenizer_revision, - truncation_side=self.model_config.truncation_side) - self.tokenizer = TokenizerGroup(**init_kwargs) - - def shutdown(self): - if (ctx := self.ctx) is not None: - ctx.destroy(linger=0) - if (task := self.log_running) is not None: - task.cancel() - if (task := self.output_handler) is not None: - task.cancel() - - ipc_paths = [ - self.connector_addr, self.decode_addr, self.prefill_addr - ] - for path in ipc_paths: - socket_path = path.replace("ipc://", "") - if os.path.exists(socket_path): - os.remove(socket_path) - - async def _run_log_running(self): - logger.info("Running requests: %d", len(self.queues)) - await asyncio.sleep(10.) - - async def _run_output_handler(self, socket: zmq.asyncio.Socket): - """ - Pull responses from Decode + Prefill engines and - distribute back to the generate() tasks. - """ - decoder = msgspec.msgpack.Decoder(PDResponse) - - socket: Optional[zmq.asyncio.Socket] = None - try: - socket = make_zmq_socket( - self.ctx, self.connector_addr, zmq.constants.PULL) - - while True: - reponse_bytes = await socket.recv().buffer - response = decoder.decode(reponse_bytes) - self.queues[response.request_id].put_nowait(response) - except: - # TODO: actually handle failure and shutdown. - raise - finally: - if socket is not None: - socket.close(linger=0) - - async def _prefill(self, - request: PDRequest, - q: asyncio.Queue[PDResponse]) -> PDResponse: - # Send request to the prefill instance. - req_bytes = self.encoder(request) - await self.to_prefill.send(req_bytes, copy=False) - - # Wait for the prefill to be done. - response = await q.get() - assert response.request_id == request.request_id - if not response.success: - # TODO: actual error handling and shutdown. - raise Exception("Failed Prefill Request.") - - return response - - async def _decode(self, - request: PDRequest, - q: asyncio.Queue[PDResponse]) -> AsyncGenerator[PDResponse]: - # Send request to the decode instance. - req_bytes = self.encoder(request) - await self.to_decode.send(req_bytes, copy=False) - - # Iterate response queue and yield each response to caller.. - finished = False - while not finished: - response = await q.get() - if not response.success: - # TODO: actual error handling and shutdown. 
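-                # success=False is the worker's structured error signal,
-                # produced by the except branch of handle_request in
-                # worker_server.py below.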
- raise Exception("Failed Decode Request.") - finished = response.finish_reason is not None - yield response - - async def generate( - self, - prompt: PromptType, - sampling_params: SamplingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, - ) -> AsyncGenerator[PDResponse]: - # Start loops on first request. - if self.output_handler is None: - self.output_handler = asyncio.create_task(self._run_output_handler()) - self.log_running = asyncio.create_task(self._run_log_running()) - - # TODO: Expand to support the full matrix. - if not isinstance(prompt, str): - raise NotImplementedError( - "We currently only support text for P/D!") - if lora_request is not None: - raise NotImplementedError( - "We currently do not suppport LoRA for P/D!") - if trace_headers is not None: - raise NotImplementedError( - "We currently do not suppport tracing for P/D!") - if prompt_adapter_request is not None: - raise NotImplementedError( - "We currently do not suppport prompt adapter for P/D!") - if priority != 0: - raise NotImplementedError( - "We currently do not support priority for P/D!") - if request_id in self.queues: - raise ValueError(f"Found duplicate request_id: {request_id}!") - - # Queue to gather output from output_handler. - q: asyncio.Queue[PDResponse] = asyncio.Queue() - self.queues[request_id] = q - - # (1) Perform the Prefill. - original_max_tokens = sampling_params.max_tokens - request = PDRequest(request_id, prompt, sampling_params) - request.sampling_params.max_tokens = 1 - response = await self._prefill(request, q) - yield response - - # (2) Perform the Decodes. - request.sampling_params.max_tokens = original_max_tokens - async for response in self._decode(request, q): - yield response - - async def beam_search( - self, - prompt: PromptType, - request_id: str, - params: BeamSearchParams, - ) -> AsyncGenerator[RequestOutput, None]: - raise NotImplementedError - - def encode( - self, - prompt: PromptType, - pooling_params: PoolingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, - ) -> AsyncGenerator[PoolingRequestOutput, None]: - raise NotImplementedError - - async def abort(self, request_id: str) -> None: - raise NotImplementedError - - async def get_model_config(self) -> ModelConfig: - return self.model_config - - async def get_decoding_config(self) -> DecodingConfig: - raise NotImplementedError - - async def get_input_preprocessor(self) -> InputPreprocessor: - raise NotImplementedError - - async def get_tokenizer( - self, - lora_request: Optional[LoRARequest] = None, - ) -> AnyTokenizer: - if lora_request is not None: - raise NotImplementedError( - "LoRA is not yet supported in the PDEngine.") - return self.tokenizer.get_lora_tokenizer(None) - - async def is_tracing_enabled(self) -> bool: - return False - - async def do_log_stats( - self, - scheduler_outputs: Optional[SchedulerOutputs] = None, - model_output: Optional[List[SamplerOutput]] = None, - ) -> None: - pass - - async def check_health(self) -> None: - pass - - async def start_profile(self) -> None: - raise NotImplementedError - - async def stop_profile(self) -> None: - raise NotImplementedError - - async def reset_prefix_cache(self, - device: Optional[Device] = None) -> None: - raise NotImplementedError - - async def sleep(self, level: int = 1) -> None: - raise 
NotImplementedError - - async def wake_up(self) -> None: - raise NotImplementedError - - async def is_sleeping(self) -> bool: - False - - async def add_lora(self, lora_request: LoRARequest) -> None: - raise NotImplementedError diff --git a/vllm/entrypoints/disaggregated/worker_server.py b/vllm/entrypoints/disaggregated/worker_server.py deleted file mode 100644 index f403c9dbd256d..0000000000000 --- a/vllm/entrypoints/disaggregated/worker_server.py +++ /dev/null @@ -1,126 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import asyncio -import msgpack -import signal -import uvloop -from typing import Optional - -import zmq -import zmq.asyncio - -from vllm.engine.async_llm_engine import AsyncEngineArgs -from vllm.engine.protocol import EngineClient -from vllm.entrypoints.disaggregated.types import PDRequest, PDResponse -from vllm.entrypoints.openai.api_server import build_async_engine_client -from vllm.logger import init_logger -from vllm.utils import FlexibleArgumentParser, set_ulimit, make_zmq_socket -from vllm.version import __version__ as VLLM_VERSION - -logger = init_logger(__name__) - -async def handle_request( - request: PDRequest, - engine: EngineClient, - socket: zmq.asyncio.Socket, - encoder: msgpack.Encoder, -) -> None: - request_id = request.request_id - try: - # 1) Generate RequestOutputs. - async for request_output in engine.generate( - prompt=request.prompt_token_ids, - sampling_params=request.sampling_params, - request_id=request_id): - - assert len(request_output.outputs) == 0, "Only support N=1 right now." - out = request_output.outputs[0] - - # 2) Convert RequestOutput --> PDResponse. - response = PDResponse( - request_id=request_id, - success=True, - text=out.text, - token_ids=out.token_ids, - finish_reasons=out.finish_reason, - stop_reason=out.stop_reason, - ) - response_bytes = encoder(response) - - # 3) Send to Connector. - await socket.send(response_bytes, copy=False) - - except Exception as e: - # TODO: actual error handling. - logger.error("Exception in Worker Routine: %s request_id: %s", e, - request_id) - response = PDResponse(request_id=request_id, success=False) - response_bytes = encoder(response) - await socket.send(response, copy=False) - -async def run_server(args, engine: EngineClient): - """Get Requests and Handle Them.""" - running_requests: set[asyncio.Task] = set() - decoder = msgpack.Decoder(PDRequest) - encoder = msgpack.Encoder() - - ctx: Optional[zmq.asyncio.Context] = None - try: - # IPC Setup. - ctx = zmq.asyncio.Context() - from_connector = make_zmq_socket( - ctx, f"ipc://{args.server_addr}", zmq.constants.PULL) - to_connector = make_zmq_socket( - ctx, f"ipc://{args.connector_addr}", zmq.constants.PUSH) - - # Main Loop. - while True: - # 1) Get request from the Connector. - pd_request_bytes = await from_connector.recv().buffer - pd_request = decoder(pd_request_bytes) - - # 2) Launch a coroutine to handle the request. - task = asyncio.create_task(handle_request( - pd_request, engine, to_connector, encoder)) - running_requests.add(task) - task.add_done_callback(running_requests.discard) - - except KeyboardInterrupt: - logger.debug("Worker server loop interrupted.") - - finally: - for task in running_requests: - task.cancel() - if ctx is not None: - ctx.destroy(linger=0) - - -async def main(args) -> None: - logger.info("vLLM P/D Worker Server %s", VLLM_VERSION) - logger.info("args: %s", args) - - # Workaround to avoid footguns where uvicorn drops requests - # with too many concurrent requests active due to ulimit. 
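-    # set_ulimit() raises the soft RLIMIT_NOFILE cap so that many
-    # concurrent sockets do not exhaust file descriptors.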
- set_ulimit() - - # Interrupt on sigterm during initialization. - def signal_handler(*_) -> None: - raise KeyboardInterrupt("terminated") - signal.signal(signal.SIGTERM, signal_handler) - - async with build_async_engine_client(args) as engine: - await run_server(args, engine) - -if __name__ == "__main__": - parser = FlexibleArgumentParser() - parser.add_argument('--connector-addr', - type=str, - required=True, - help='The address of the connector.') - parser.add_argument('--worker-addr', - type=str, - required=True, - help='The address of the worker.') - AsyncEngineArgs.add_cli_args(parser) - args = parser.parse_args() - uvloop.run(main(args))
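Taken together, the deleted connector and worker define a small wire protocol: msgpack-encoded structs over ZMQ PUSH/PULL ipc sockets, with requests flowing connector to worker and responses coming back on a separate socket. The sketch below runs that round-trip end to end in one process. The PDRequest/PDResponse classes are simplified stand-ins for vllm.entrypoints.disaggregated.types (whose real definitions are not part of this patch, so the fields here are assumptions), and the demo-* ipc paths are arbitrary:

```python
# Self-contained sketch of the connector<->worker wire protocol:
# msgspec msgpack structs over ZMQ PUSH/PULL ipc sockets.
import asyncio
import os
from typing import Optional

import msgspec
import zmq
import zmq.asyncio


class PDRequest(msgspec.Struct):  # simplified stand-in
    request_id: str
    prompt: str


class PDResponse(msgspec.Struct):  # simplified stand-in
    request_id: str
    success: bool
    text: str = ""
    finish_reason: Optional[str] = None


async def main() -> None:
    ctx = zmq.asyncio.Context()

    # Worker side: PULL requests in, PUSH responses out (mirrors run_server).
    from_connector = ctx.socket(zmq.PULL)
    from_connector.bind("ipc://demo-worker")
    to_connector = ctx.socket(zmq.PUSH)
    to_connector.connect("ipc://demo-connector")

    # Connector side: PUSH requests out, PULL responses in (mirrors PDEngine).
    to_worker = ctx.socket(zmq.PUSH)
    to_worker.connect("ipc://demo-worker")
    from_worker = ctx.socket(zmq.PULL)
    from_worker.bind("ipc://demo-connector")

    encoder = msgspec.msgpack.Encoder()

    # Connector -> worker.
    await to_worker.send(encoder.encode(PDRequest("req-0", "hello")))

    # Worker decodes the request and answers on its own PUSH socket.
    req = msgspec.msgpack.Decoder(PDRequest).decode(await from_connector.recv())
    await to_connector.send(
        encoder.encode(PDResponse(req.request_id, True, "world", "stop")))

    # Connector receives the response.
    out = msgspec.msgpack.Decoder(PDResponse).decode(await from_worker.recv())
    print(out.request_id, out.text, out.finish_reason)

    ctx.destroy(linger=0)
    # ipc:// endpoints are plain filesystem paths; clean them up the same
    # way PDEngine.shutdown does.
    for path in ("demo-worker", "demo-connector"):
        if os.path.exists(path):
            os.remove(path)


asyncio.run(main())
```

PUSH/PULL keeps each direction one-way, which is why the worker never replies on the socket it reads from; on the connector side, PDEngine routes each incoming PDResponse to the right generate() task by looking up request_id in its queues dict.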